From c7a9d89d85bc5bc7c9c826013515b532cc6ba949 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 13:24:01 +0100 Subject: [PATCH 001/281] Just const_fold2 + inside that partial_value (taken from hugr_core) --- hugr-passes/src/const_fold2.rs | 2 + hugr-passes/src/const_fold2/datalog.rs | 254 ++++++++++ .../src/const_fold2/datalog/context.rs | 67 +++ hugr-passes/src/const_fold2/datalog/test.rs | 232 +++++++++ hugr-passes/src/const_fold2/datalog/utils.rs | 390 +++++++++++++++ hugr-passes/src/const_fold2/partial_value.rs | 454 ++++++++++++++++++ .../src/const_fold2/partial_value/test.rs | 346 +++++++++++++ .../const_fold2/partial_value/value_handle.rs | 245 ++++++++++ hugr-passes/src/lib.rs | 1 + 9 files changed, 1991 insertions(+) create mode 100644 hugr-passes/src/const_fold2.rs create mode 100644 hugr-passes/src/const_fold2/datalog.rs create mode 100644 hugr-passes/src/const_fold2/datalog/context.rs create mode 100644 hugr-passes/src/const_fold2/datalog/test.rs create mode 100644 hugr-passes/src/const_fold2/datalog/utils.rs create mode 100644 hugr-passes/src/const_fold2/partial_value.rs create mode 100644 hugr-passes/src/const_fold2/partial_value/test.rs create mode 100644 hugr-passes/src/const_fold2/partial_value/value_handle.rs diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs new file mode 100644 index 000000000..dbe4464fd --- /dev/null +++ b/hugr-passes/src/const_fold2.rs @@ -0,0 +1,2 @@ +mod datalog; +pub mod partial_value; \ No newline at end of file diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs new file mode 100644 index 000000000..d7df9c1e6 --- /dev/null +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -0,0 +1,254 @@ +use ascent::lattice::{ord_lattice::OrdLattice, BoundedLattice, Dual, Lattice}; +use delegate::delegate; +use itertools::{zip_eq, Itertools}; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::sync::{Arc, Mutex}; + +use either::Either; +use hugr_core::ops::{OpTag, OpTrait, Value}; +use hugr_core::partial_value::{PartialValue, ValueHandle, ValueKey}; +use hugr_core::types::{EdgeKind, FunctionType, SumType, Type, TypeEnum, TypeRow}; +use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; + +mod context; +mod utils; + +use context::DataflowContext; +pub use utils::{TailLoopTermination, ValueRow, IO, PV}; + +pub trait DFContext: AsRef + Clone + Eq + Hash + std::ops::Deref {} + +ascent::ascent! { + // The trait-indirection layer here means we can just write 'C' but in practice ATM + // DataflowContext (for H: HugrView) would be sufficient, there's really no + // point in using anything else yet. However DFContext will be useful when we + // move interpretation of nodes out into a trait method. + struct AscentProgram; + relation context(C); + relation out_wire_value_proto(Node, OutgoingPort, PV); + + relation node(C, Node); + relation in_wire(C, Node, IncomingPort); + relation out_wire(C, Node, OutgoingPort); + relation parent_of_node(C, Node, Node); + relation io_node(C, Node, Node, IO); + lattice out_wire_value(C, Node, OutgoingPort, PV); + lattice node_in_value_row(C, Node, ValueRow); + lattice in_wire_value(C, Node, IncomingPort, PV); + + node(c, n) <-- context(c), for n in c.nodes(); + + in_wire(c, n,p) <-- node(c, n), for p in utils::value_inputs(c, *n); + + out_wire(c, n,p) <-- node(c, n), for p in utils::value_outputs(c, *n); + + parent_of_node(c, parent, child) <-- + node(c, child), if let Some(parent) = c.get_parent(*child); + + io_node(c, parent, child, io) <-- node(c, parent), + if let Some([i,o]) = c.get_io(*parent), + for (child,io) in [(i,IO::Input),(o,IO::Output)]; + // We support prepopulating out_wire_value via out_wire_value_proto. + // + // out wires that do not have prepopulation values are initialised to bottom. + out_wire_value(c, n,p, PV::bottom()) <-- out_wire(c, n,p); + out_wire_value(c, n, p, v) <-- out_wire(c,n,p) , out_wire_value_proto(n, p, v); + + in_wire_value(c, n, ip, v) <-- in_wire(c, n, ip), + if let Some((m,op)) = c.single_linked_output(*n, *ip), + out_wire_value(c, m, op, v); + + + node_in_value_row(c, n, utils::bottom_row(c, *n)) <-- node(c, n); + node_in_value_row(c, n, utils::singleton_in_row(c, n, p, v.clone())) <-- in_wire_value(c, n, p, v); + + + // Per node-type rules + // TODO do all leaf ops with a rule + // define `fn propagate_leaf_op(Context, Node, ValueRow) -> ValueRow + + // LoadConstant + relation load_constant_node(C, Node); + load_constant_node(c, n) <-- node(c, n), if c.get_optype(*n).is_load_constant(); + + out_wire_value(c, n, 0.into(), utils::partial_value_from_load_constant(c, *n)) <-- + load_constant_node(c, n); + + + // MakeTuple + relation make_tuple_node(C, Node); + make_tuple_node(c, n) <-- node(c, n), if c.get_optype(*n).is_make_tuple(); + + out_wire_value(c, n, 0.into(), utils::partial_value_tuple_from_value_row(vs.clone())) <-- + make_tuple_node(c, n), node_in_value_row(c, n, vs); + + + // UnpackTuple + relation unpack_tuple_node(C, Node); + unpack_tuple_node(c,n) <-- node(c, n), if c.get_optype(*n).is_unpack_tuple(); + + out_wire_value(c, n, p, v.tuple_field_value(p.index())) <-- + unpack_tuple_node(c, n), + in_wire_value(c, n, IncomingPort::from(0), v), + out_wire(c, n, p); + + + // DFG + relation dfg_node(C, Node); + dfg_node(c,n) <-- node(c, n), if c.get_optype(*n).is_dfg(); + + out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- dfg_node(c,dfg), + io_node(c, dfg, i, IO::Input), in_wire_value(c, dfg, p, v); + + out_wire_value(c, dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(c,dfg), + io_node(c,dfg,o, IO::Output), in_wire_value(c, o, p, v); + + + // TailLoop + relation tail_loop_node(C, Node); + tail_loop_node(c,n) <-- node(c, n), if c.get_optype(*n).is_tail_loop(); + + // inputs of tail loop propagate to Input node of child region + out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- tail_loop_node(c, tl), + io_node(c,tl,i, IO::Input), in_wire_value(c, tl, p, v); + + // Output node of child region propagate to Input node of child region + out_wire_value(c, in_n, out_p, v) <-- tail_loop_node(c, tl_n), + io_node(c,tl_n,in_n, IO::Input), + io_node(c,tl_n,out_n, IO::Output), + node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node + if out_in_row[0].supports_tag(0), // if it is possible for tag to be 0 + if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), + let variant_len = tailloop.just_inputs.len(), + for (out_p, v) in out_in_row.iter(c, *out_n).flat_map( + |(input_p, v)| utils::outputs_for_variant(input_p, 0, variant_len, v) + ); + + // Output node of child region propagate to outputs of tail loop + out_wire_value(c, tl_n, out_p, v) <-- tail_loop_node(c, tl_n), + io_node(c,tl_n,out_n, IO::Output), + node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node + if out_in_row[0].supports_tag(1), // if it is possible for the tag to be 1 + if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), + let variant_len = tailloop.just_outputs.len(), + for (out_p, v) in out_in_row.iter(c, *out_n).flat_map( + |(input_p, v)| utils::outputs_for_variant(input_p, 1, variant_len, v) + ); + + lattice tail_loop_termination(C,Node,TailLoopTermination); + tail_loop_termination(c,tl_n,TailLoopTermination::bottom()) <-- + tail_loop_node(c,tl_n); + tail_loop_termination(c,tl_n,TailLoopTermination::from_control_value(v)) <-- + tail_loop_node(c,tl_n), + io_node(c,tl,out_n, IO::Output), + in_wire_value(c, out_n, IncomingPort::from(0), v); + + + // Conditional + relation conditional_node(C, Node); + relation case_node(C,Node,usize, Node); + + conditional_node (c,n)<-- node(c, n), if c.get_optype(*n).is_conditional(); + case_node(c,cond,i, case) <-- conditional_node(c,cond), + for (i, case) in c.children(*cond).enumerate(), + if c.get_optype(case).is_case(); + + // inputs of conditional propagate into case nodes + out_wire_value(c, i_node, i_p, v) <-- + case_node(c, cond, case_index, case), + io_node(c, case, i_node, IO::Input), + in_wire_value(c, cond, cond_in_p, cond_in_v), + if let Some(conditional) = c.get_optype(*cond).as_conditional(), + let variant_len = conditional.sum_rows[*case_index].len(), + for (i_p, v) in utils::outputs_for_variant(*cond_in_p, *case_index, variant_len, cond_in_v); + + // outputs of case nodes propagate to outputs of conditional + out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- + case_node(c, cond, _, case), + io_node(c, case, o, IO::Output), + in_wire_value(c, o, o_p, v); + + lattice case_reachable(C, Node, Node, bool); + case_reachable(c, cond, case, reachable) <-- case_node(c,cond,i,case), + in_wire_value(c, cond, IncomingPort::from(0), v), + let reachable = v.supports_tag(*i); + +} + +// TODO This should probably be called 'Analyser' or something +struct Machine( + AscentProgram>, + Option>, +); + +/// Usage: +/// 1. [Self::new()] +/// 2. Zero or more [Self::propolutate_out_wires] with initial values +/// 3. Exactly one [Self::run_hugr] to do the analysis +/// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] +impl Machine { + pub fn new() -> Self { + Self(Default::default(), None) + } + + pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator) { + assert!(self.1.is_none()); + self.0.out_wire_value_proto.extend( + wires + .into_iter() + .map(|(w, v)| (w.node(), w.source(), v.into())), + ); + } + + pub fn run_hugr(&mut self, hugr: H) { + assert!(self.1.is_none()); + self.0.context.push((DataflowContext::new(hugr),)); + self.0.run(); + self.1 = Some( + self.0 + .out_wire_value + .iter() + .map(|(_, n, p, v)| (Wire::new(*n, *p), v.clone().into())) + .collect(), + ) + } + + pub fn read_out_wire_partial_value(&self, w: Wire) -> Option { + self.1.as_ref().unwrap().get(&w).cloned() + } + + pub fn read_out_wire_value(&self, hugr: H, w: Wire) -> Option { + // dbg!(&w); + let pv = self.read_out_wire_partial_value(w)?; + // dbg!(&pv); + let (_, typ) = hugr + .out_value_types(w.node()) + .find(|(p, _)| *p == w.source()) + .unwrap(); + pv.try_into_value(&typ).ok() + } + + pub fn tail_loop_terminates(&self, hugr: H, node: Node) -> TailLoopTermination { + assert!(hugr.get_optype(node).is_tail_loop()); + self.0 + .tail_loop_termination + .iter() + .find_map(|(_, n, v)| (n == &node).then_some(*v)) + .unwrap() + } + + pub fn case_reachable(&self, hugr: H, case: Node) -> bool { + assert!(hugr.get_optype(case).is_case()); + let cond = hugr.get_parent(case).unwrap(); + assert!(hugr.get_optype(cond).is_conditional()); + self.0 + .case_reachable + .iter() + .find_map(|(_, cond2, case2, i)| (&cond == cond2 && &case == case2).then_some(*i)) + .unwrap() + } +} + +#[cfg(test)] +mod test; diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs new file mode 100644 index 000000000..92c0c3285 --- /dev/null +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -0,0 +1,67 @@ +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::ops::Deref; +use std::sync::atomic::AtomicUsize; +use std::sync::{Arc, Mutex}; + +use hugr_core::hugr::internal::HugrInternals; +use hugr_core::ops::Value; +use hugr_core::partial_value::{ValueHandle, ValueKey}; +use hugr_core::{Hugr, HugrView, Node}; + +use super::DFContext; + +#[derive(Debug)] +pub(super) struct DataflowContext(Arc); + +impl DataflowContext { + pub fn new(hugr: H) -> Self { + Self(Arc::new(hugr)) + } +} + +// Deriving Clone requires H:HugrView to implement Clone, +// but we don't need that as we only clone the Arc. +impl Clone for DataflowContext { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl Hash for DataflowContext { + fn hash(&self, state: &mut I) {} +} + +impl PartialEq for DataflowContext { + fn eq(&self, other: &Self) -> bool { + // Any AscentProgram should have only one DataflowContext + assert_eq!(self as *const _, other as *const _); + true + } +} + +impl Eq for DataflowContext {} + +impl PartialOrd for DataflowContext { + fn partial_cmp(&self, other: &Self) -> Option { + // Any AscentProgram should have only one DataflowContext + assert_eq!(self as *const _, other as *const _); + Some(std::cmp::Ordering::Equal) + } +} + +impl Deref for DataflowContext { + type Target = Hugr; + + fn deref(&self) -> &Self::Target { + self.0.base_hugr() + } +} + +impl AsRef for DataflowContext { + fn as_ref(&self) -> &Hugr { + self.base_hugr() + } +} + +impl DFContext for DataflowContext {} diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs new file mode 100644 index 000000000..4e086c4b7 --- /dev/null +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -0,0 +1,232 @@ +use hugr_core::{ + builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, + extension::{prelude::BOOL_T, ExtensionSet, EMPTY_REG}, + ops::{handle::NodeHandle, OpTrait, UnpackTuple, Value}, + type_row, + types::{FunctionType, SumType}, + Extension, +}; + +use hugr_core::partial_value::PartialValue; + +use super::*; + +#[test] +fn test_make_tuple() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let v1 = builder.add_load_value(Value::false_val()); + let v2 = builder.add_load_value(Value::true_val()); + let v3 = builder.make_tuple([v1, v2]).unwrap(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + machine.run_hugr(&hugr); + + let x = machine.read_out_wire_value(&hugr, v3).unwrap(); + assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); +} + +#[test] +fn test_unpack_tuple() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let v1 = builder.add_load_value(Value::false_val()); + let v2 = builder.add_load_value(Value::true_val()); + let v3 = builder.make_tuple([v1, v2]).unwrap(); + let [o1, o2] = builder + .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T, BOOL_T]), [v3]) + .unwrap() + .outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + machine.run_hugr(&hugr); + + let o1_r = machine.read_out_wire_value(&hugr, o1).unwrap(); + assert_eq!(o1_r, Value::false_val()); + let o2_r = machine.read_out_wire_value(&hugr, o2).unwrap(); + assert_eq!(o2_r, Value::true_val()); +} + +#[test] +fn test_unpack_const() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let v1 = builder.add_load_value(Value::tuple([Value::true_val()])); + let [o] = builder + .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T]), [v1]) + .unwrap() + .outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + let c = machine.run_hugr(&hugr); + + let o_r = machine.read_out_wire_value(&hugr, o).unwrap(); + assert_eq!(o_r, Value::true_val()); +} + +#[test] +fn test_tail_loop_never_iterates() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let r_v = Value::unit_sum(3, 6).unwrap(); + let r_w = builder.add_load_value( + Value::sum( + 1, + [r_v.clone()], + SumType::new([type_row![], r_v.get_type().into()]), + ) + .unwrap(), + ); + let tlb = builder + .tail_loop_builder([], [], vec![r_v.get_type()].into()) + .unwrap(); + let tail_loop = tlb.finish_with_outputs(r_w, []).unwrap(); + let [tl_o] = tail_loop.outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + machine.run_hugr(&hugr); + // dbg!(&machine.tail_loop_io_node); + // dbg!(&machine.out_wire_value); + + let o_r = machine.read_out_wire_value(&hugr, tl_o).unwrap(); + assert_eq!(o_r, r_v); + assert_eq!( + TailLoopTermination::ExactlyZeroContinues, + machine.tail_loop_terminates(&hugr, tail_loop.node()) + ) +} + +#[test] +fn test_tail_loop_always_iterates() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let r_w = builder + .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); + let true_w = builder.add_load_value(Value::true_val()); + + let tlb = builder + .tail_loop_builder([], [(BOOL_T, true_w)], vec![BOOL_T].into()) + .unwrap(); + + // r_w has tag 0, so we always continue; + // we put true in our "other_output", but we should not propagate this to + // output because r_w never supports 1. + let tail_loop = tlb.finish_with_outputs(r_w, [true_w]).unwrap(); + + let [tl_o1, tl_o2] = tail_loop.outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + machine.run_hugr(&hugr); + + let o_r1 = machine.read_out_wire_partial_value(tl_o1).unwrap(); + assert_eq!(o_r1, PartialValue::bottom()); + let o_r2 = machine.read_out_wire_partial_value(tl_o2).unwrap(); + assert_eq!(o_r2, PartialValue::bottom()); + assert_eq!( + TailLoopTermination::bottom(), + machine.tail_loop_terminates(&hugr, tail_loop.node()) + ) +} + +#[test] +fn test_tail_loop_iterates_twice() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + // let var_type = Type::new_sum([type_row![BOOL_T,BOOL_T], type_row![BOOL_T,BOOL_T]]); + + let true_w = builder.add_load_value(Value::true_val()); + let false_w = builder.add_load_value(Value::false_val()); + + // let r_w = builder + // .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); + let tlb = builder + .tail_loop_builder([], [(BOOL_T, false_w), (BOOL_T, true_w)], vec![].into()) + .unwrap(); + assert_eq!( + tlb.loop_signature().unwrap().dataflow_signature().unwrap(), + FunctionType::new_endo(type_row![BOOL_T, BOOL_T]) + ); + let [in_w1, in_w2] = tlb.input_wires_arr(); + let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap(); + + // let optype = builder.hugr().get_optype(tail_loop.node()); + // for p in builder.hugr().node_outputs(tail_loop.node()) { + // use hugr_core::ops::OpType; + // println!("{:?}, {:?}", p, optype.port_kind(p)); + + // } + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + // TODO once we can do conditionals put these wires inside `just_outputs` and + // we should be able to propagate their values + let [o_w1, o_w2, _] = tail_loop.outputs_arr(); + + let mut machine = Machine::new(); + let c = machine.run_hugr(&hugr); + // dbg!(&machine.tail_loop_io_node); + // dbg!(&machine.out_wire_value); + + // TODO these hould be the propagated values for now they will bt join(true,false) + let o_r1 = machine.read_out_wire_partial_value(o_w1).unwrap(); + // assert_eq!(o_r1, PartialValue::top()); + let o_r2 = machine.read_out_wire_partial_value(o_w2).unwrap(); + // assert_eq!(o_r2, Value::true_val()); + assert_eq!( + TailLoopTermination::Top, + machine.tail_loop_terminates(&hugr, tail_loop.node()) + ) +} + +#[test] +fn conditional() { + let variants = vec![type_row![], type_row![], type_row![BOOL_T]]; + let cond_t = Type::new_sum(variants.clone()); + let mut builder = DFGBuilder::new(FunctionType::new( + Into::::into(cond_t), + type_row![], + )) + .unwrap(); + let [arg_w] = builder.input_wires_arr(); + + let true_w = builder.add_load_value(Value::true_val()); + let false_w = builder.add_load_value(Value::false_val()); + + let mut cond_builder = builder + .conditional_builder( + (variants, arg_w), + [(BOOL_T, true_w)], + type_row!(BOOL_T, BOOL_T), + ExtensionSet::default(), + ) + .unwrap(); + // will be unreachable + let case1_b = cond_builder.case_builder(0).unwrap(); + let case1 = case1_b.finish_with_outputs([false_w, false_w]).unwrap(); + + let case2_b = cond_builder.case_builder(1).unwrap(); + let [c2a] = case2_b.input_wires_arr(); + let case2 = case2_b.finish_with_outputs([false_w, c2a]).unwrap(); + + let case3_b = cond_builder.case_builder(2).unwrap(); + let [c3_1, c3_2] = case3_b.input_wires_arr(); + let case3 = case3_b.finish_with_outputs([c3_1, false_w]).unwrap(); + + let cond = cond_builder.finish_sub_container().unwrap(); + + let [cond_o1, cond_o2] = cond.outputs_arr(); + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + let arg_pv = + PartialValue::variant(1, []).join(PartialValue::variant(2, [PartialValue::variant(0, [])])); + machine.propolutate_out_wires([(arg_w, arg_pv)]); + machine.run_hugr(&hugr); + + let cond_r1 = machine.read_out_wire_value(&hugr, cond_o1).unwrap(); + assert_eq!(cond_r1, Value::false_val()); + assert!(machine.read_out_wire_value(&hugr, cond_o2).is_none()); + + assert!(!machine.case_reachable(&hugr, case1.node())); + assert!(machine.case_reachable(&hugr, case2.node())); + assert!(machine.case_reachable(&hugr, case3.node())); +} diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs new file mode 100644 index 000000000..9c2e46ae3 --- /dev/null +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -0,0 +1,390 @@ +// proptest-derive generates many of these warnings. +// https://github.com/rust-lang/rust/issues/120363 +// https://github.com/proptest-rs/proptest/issues/447 +#![cfg_attr(test, allow(non_local_definitions))] + +use std::{cmp::Ordering, ops::Index, sync::Arc}; + +use ascent::lattice::{BoundedLattice, Lattice}; +use either::Either; +use hugr_core::{ + ops::OpTrait as _, + partial_value::{PartialValue, ValueHandle}, + types::{EdgeKind, TypeRow}, + HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, +}; +use itertools::zip_eq; + +#[cfg(test)] +use proptest_derive::Arbitrary; + +#[derive(PartialEq, Eq, Hash, PartialOrd, Clone, Debug)] +pub struct PV(PartialValue); + +impl From for PV { + fn from(inner: PartialValue) -> Self { + Self(inner) + } +} + +impl PV { + pub fn tuple_field_value(&self, idx: usize) -> Self { + self.variant_field_value(0, idx) + } + + /// TODO the arguments here are not pretty, two usizes, better not mix them + /// up!!! + pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { + self.0.variant_field_value(variant, idx).into() + } + + pub fn supports_tag(&self, tag: usize) -> bool { + self.0.supports_tag(tag) + } +} + +impl From for PartialValue { + fn from(value: PV) -> Self { + value.0 + } +} + +impl From for PV { + fn from(inner: ValueHandle) -> Self { + Self(inner.into()) + } +} + +impl Lattice for PV { + fn meet(self, other: Self) -> Self { + self.0.meet(other.0).into() + } + + fn meet_mut(&mut self, other: Self) -> bool { + self.0.meet_mut(other.0) + } + + fn join(self, other: Self) -> Self { + self.0.join(other.0).into() + } + + fn join_mut(&mut self, other: Self) -> bool { + self.0.join_mut(other.0) + } +} + +impl BoundedLattice for PV { + fn bottom() -> Self { + PartialValue::bottom().into() + } + + fn top() -> Self { + PartialValue::top().into() + } +} + +#[derive(PartialEq, Clone, Eq, Hash, PartialOrd)] +pub struct ValueRow(Vec); + +impl ValueRow { + fn new(len: usize) -> Self { + Self(vec![PV::bottom(); len]) + } + + fn singleton(len: usize, idx: usize, v: PV) -> Self { + assert!(idx < len); + let mut r = Self::new(len); + r.0[idx] = v; + r + } + + fn singleton_from_row(r: &TypeRow, idx: usize, v: PV) -> Self { + Self::singleton(r.len(), idx, v) + } + + fn bottom_from_row(r: &TypeRow) -> Self { + Self::new(r.len()) + } + + pub fn iter<'b>( + &'b self, + h: &'b impl HugrView, + n: Node, + ) -> impl Iterator + 'b { + zip_eq(value_inputs(h, n), self.0.iter()) + } + + // fn initialised(&self) -> bool { + // self.0.iter().all(|x| x != &PV::top()) + // } +} + +impl Lattice for ValueRow { + fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn join_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.join_mut(v2); + } + changed + } + + fn meet_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.meet_mut(v2); + } + changed + } +} + +impl IntoIterator for ValueRow { + type Item = PV; + + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl Index for ValueRow +where + Vec: Index, +{ + type Output = as Index>::Output; + + fn index(&self, index: Idx) -> &Self::Output { + self.0.index(index) + } +} + +pub(super) fn bottom_row(h: &impl HugrView, n: Node) -> ValueRow { + if let Some(sig) = h.signature(n) { + ValueRow::new(sig.input_count()) + } else { + ValueRow::new(0) + } +} + +pub(super) fn singleton_in_row(h: &impl HugrView, n: &Node, ip: &IncomingPort, v: PV) -> ValueRow { + let Some(sig) = h.signature(*n) else { + panic!("dougrulz"); + }; + if sig.input_count() <= ip.index() { + panic!( + "bad port index: {} >= {}: {}", + ip.index(), + sig.input_count(), + h.get_optype(*n).description() + ); + } + ValueRow::singleton_from_row(&h.signature(*n).unwrap().input, ip.index(), v) +} + +pub(super) fn partial_value_from_load_constant(h: &impl HugrView, node: Node) -> PV { + let load_op = h.get_optype(node).as_load_constant().unwrap(); + let const_node = h + .single_linked_output(node, load_op.constant_port()) + .unwrap() + .0; + let const_op = h.get_optype(const_node).as_const().unwrap(); + ValueHandle::new(const_node.into(), Arc::new(const_op.value().clone())).into() +} + +pub(super) fn partial_value_tuple_from_value_row(r: ValueRow) -> PV { + PartialValue::variant(0, r.into_iter().map(|x| x.0)).into() +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum IO { + Input, + Output, +} + +pub(super) fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { + h.in_value_types(n).map(|x| x.0) +} + +pub(super) fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { + h.out_value_types(n).map(|x| x.0) +} + +// We have several cases where sum types propagate to different places depending +// on their variant tag: +// - From the input of a conditional to the inputs of it's case nodes +// - From the input of the output node of a tail loop to the output of the input node of the tail loop +// - From the input of the output node of a tail loop to the output of tail loop node +// - From the input of a the output node of a dataflow block to the output of the input node of a dataflow block +// - From the input of a the output node of a dataflow block to the output of the cfg +// +// For a value `v` on an incoming porg `output_p`, compute the (out port,value) +// pairs that should be propagated for a given variant tag. We must also supply +// the length of this variant because it cannot always be deduced from the other +// inputs. +// +// If `v` does not support `variant_tag`, then all propagated values will be bottom.` +// +// If `output_p.index()` is 0 then the result is the contents of the variant. +// Otherwise, it is the single "other_output". +// +// TODO doctests +pub(super) fn outputs_for_variant<'a>( + output_p: IncomingPort, + variant_tag: usize, + variant_len: usize, + v: &'a PV, +) -> impl Iterator + 'a { + if output_p.index() == 0 { + Either::Left( + (0..variant_len).map(move |i| (i.into(), v.variant_field_value(variant_tag, i))), + ) + } else { + let v = if v.supports_tag(variant_tag) { + v.clone() + } else { + PV::bottom() + }; + Either::Right(std::iter::once(( + (variant_len + output_p.index() - 1).into(), + v, + ))) + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] +#[cfg_attr(test, derive(Arbitrary))] +pub enum TailLoopTermination { + Bottom, + ExactlyZeroContinues, + Top, +} + +impl TailLoopTermination { + pub fn from_control_value(v: &PV) -> Self { + let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); + if may_break && !may_continue { + Self::ExactlyZeroContinues + } else if may_break && may_continue { + Self::top() + } else { + Self::bottom() + } + } +} + +impl PartialOrd for TailLoopTermination { + fn partial_cmp(&self, other: &Self) -> Option { + if self == other { + return Some(std::cmp::Ordering::Equal); + }; + match (self, other) { + (Self::Bottom, _) => Some(Ordering::Less), + (_, Self::Bottom) => Some(Ordering::Greater), + (Self::Top, _) => Some(Ordering::Greater), + (_, Self::Top) => Some(Ordering::Less), + _ => None, + } + } +} + +impl Lattice for TailLoopTermination { + fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn meet_mut(&mut self, other: Self) -> bool { + // let new_self = &mut self; + match (*self).partial_cmp(&other) { + Some(Ordering::Greater) => { + *self = other; + true + } + Some(_) => false, + _ => { + *self = Self::Bottom; + true + } + } + } + + fn join_mut(&mut self, other: Self) -> bool { + match (*self).partial_cmp(&other) { + Some(Ordering::Less) => { + *self = other; + true + } + Some(_) => false, + _ => { + *self = Self::Top; + true + } + } + } +} + +impl BoundedLattice for TailLoopTermination { + fn bottom() -> Self { + Self::Bottom + } + + fn top() -> Self { + Self::Top + } +} + +#[cfg(test)] +#[cfg_attr(test, allow(non_local_definitions))] +mod test { + use super::*; + use proptest::prelude::*; + + proptest! { + #[test] + fn bounded_lattice(v: TailLoopTermination) { + prop_assert!(v <= TailLoopTermination::top()); + prop_assert!(v >= TailLoopTermination::bottom()); + } + + #[test] + fn meet_join_self_noop(v1: TailLoopTermination) { + let mut subject = v1.clone(); + + assert_eq!(v1.clone(), v1.clone().join(v1.clone())); + assert!(!subject.join_mut(v1.clone())); + assert_eq!(subject, v1); + + assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); + assert!(!subject.meet_mut(v1.clone())); + assert_eq!(subject, v1); + } + + #[test] + fn lattice(v1: TailLoopTermination, v2: TailLoopTermination) { + let meet = v1.clone().meet(v2.clone()); + prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); + prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); + + let join = v1.clone().join(v2.clone()); + prop_assert!(join >= v1, "join not >=: {:#?}", &join); + prop_assert!(join >= v2, "join not >=: {:#?}", &join); + } + } +} diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs new file mode 100644 index 000000000..0442aa4c9 --- /dev/null +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -0,0 +1,454 @@ +#![allow(missing_docs)] +use std::cmp::Ordering; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; + +use itertools::{zip_eq, Itertools as _}; + +use crate::ops::Value; +use crate::types::{Type, TypeEnum}; + +mod value_handle; + +pub use value_handle::{ValueHandle, ValueKey}; + +// TODO ALAN inline into PartialValue +#[derive(PartialEq, Clone, Eq)] +struct PartialSum(HashMap>); + +impl PartialSum { + pub fn variant(tag: usize, values: impl IntoIterator) -> Self { + Self([(tag, values.into_iter().collect())].into_iter().collect()) + } + + pub fn num_variants(&self) -> usize { + self.0.len() + } + + fn assert_variants(&self) { + assert_ne!(self.num_variants(), 0); + for pv in self.0.values().flat_map(|x| x.iter()) { + pv.assert_invariants(); + } + } + + pub fn variant_field_value(&self, variant: usize, idx: usize) -> PartialValue { + if let Some(row) = self.0.get(&variant) { + assert!(row.len() > idx); + row[idx].clone() + } else { + PartialValue::bottom() + } + } + + pub fn try_into_value(self, typ: &Type) -> Result { + let Ok((k, v)) = self.0.iter().exactly_one() else { + Err(self)? + }; + + let TypeEnum::Sum(st) = typ.as_type_enum() else { + Err(self)? + }; + let Some(r) = st.get_variant(*k) else { + Err(self)? + }; + if v.len() != r.len() { + return Err(self); + } + match zip_eq(v.into_iter(), r.into_iter()) + .map(|(v, t)| v.clone().try_into_value(t)) + .collect::, _>>() + { + Ok(vs) => Value::sum(*k, vs, st.clone()).map_err(|_| self), + Err(_) => Err(self), + } + } + + // unsafe because we panic if any common rows have different lengths + fn join_mut_unsafe(&mut self, other: Self) -> bool { + let mut changed = false; + + for (k, v) in other.0 { + if let Some(row) = self.0.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.join_mut(rhs); + } + } else { + self.0.insert(k, v); + changed = true; + } + } + changed + } + + // unsafe because we panic if any common rows have different lengths + fn meet_mut_unsafe(&mut self, other: Self) -> bool { + let mut changed = false; + let mut keys_to_remove = vec![]; + for k in self.0.keys() { + if !other.0.contains_key(k) { + keys_to_remove.push(*k); + } + } + for (k, v) in other.0 { + if let Some(row) = self.0.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.meet_mut(rhs); + } + } else { + keys_to_remove.push(k); + } + } + for k in keys_to_remove { + self.0.remove(&k); + changed = true; + } + changed + } + + pub fn supports_tag(&self, tag: usize) -> bool { + self.0.contains_key(&tag) + } +} + +impl PartialOrd for PartialSum { + fn partial_cmp(&self, other: &Self) -> Option { + let max_key = self.0.keys().chain(other.0.keys()).copied().max().unwrap(); + let (mut keys1, mut keys2) = (vec![0; max_key + 1], vec![0; max_key + 1]); + for k in self.0.keys() { + keys1[*k] = 1; + } + + for k in other.0.keys() { + keys2[*k] = 1; + } + + if let Some(ord) = keys1.partial_cmp(&keys2) { + if ord != Ordering::Equal { + return Some(ord); + } + } else { + return None; + } + for (k, lhs) in &self.0 { + let Some(rhs) = other.0.get(&k) else { + unreachable!() + }; + match lhs.partial_cmp(rhs) { + Some(Ordering::Equal) => continue, + x => { + return x; + } + } + } + Some(Ordering::Equal) + } +} + +impl std::fmt::Debug for PartialSum { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl Hash for PartialSum { + fn hash(&self, state: &mut H) { + for (k, v) in &self.0 { + k.hash(state); + v.hash(state); + } + } +} + +impl TryFrom for PartialSum { + type Error = ValueHandle; + + fn try_from(value: ValueHandle) -> Result { + match value.value() { + Value::Tuple { vs } => { + let vec = (0..vs.len()) + .map(|i| PartialValue::from(value.index(i)).into()) + .collect(); + return Ok(Self([(0, vec)].into_iter().collect())); + } + Value::Sum { tag, values, .. } => { + let vec = (0..values.len()) + .map(|i| PartialValue::from(value.index(i)).into()) + .collect(); + return Ok(Self([(*tag, vec)].into_iter().collect())); + } + _ => (), + }; + Err(value) + } +} + +#[derive(PartialEq, Clone, Eq, Hash, Debug)] +pub enum PartialValue { + Bottom, + Value(ValueHandle), + PartialSum(PartialSum), + Top, +} + +impl From for PartialValue { + fn from(v: ValueHandle) -> Self { + TryInto::::try_into(v).map_or_else(Self::Value, Self::PartialSum) + } +} + +impl From for PartialValue { + fn from(v: PartialSum) -> Self { + Self::PartialSum(v) + } +} + +impl PartialValue { + // const BOTTOM: Self = Self::Bottom; + // const BOTTOM_REF: &'static Self = &Self::BOTTOM; + + // fn initialised(&self) -> bool { + // !self.is_top() + // } + + // fn is_top(&self) -> bool { + // self == &PartialValue::Top + // } + + fn assert_invariants(&self) { + match self { + Self::PartialSum(ps) => { + ps.assert_variants(); + } + Self::Value(v) => { + assert!(matches!(v.clone().into(), Self::Value(_))) + } + _ => {} + } + } + + pub fn try_into_value(self, typ: &Type) -> Result { + let r = match self { + Self::Value(v) => Ok(v.value().clone()), + Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), + x => Err(x), + }?; + assert_eq!(typ, &r.get_type()); + Ok(r) + } + + fn join_mut_value_handle(&mut self, vh: ValueHandle) -> bool { + self.assert_invariants(); + match &*self { + Self::Top => return false, + Self::Value(v) if v == &vh => return false, + Self::Value(v) => { + *self = Self::Top; + } + Self::PartialSum(_) => match vh.into() { + Self::Value(_) => { + *self = Self::Top; + } + other => return self.join_mut(other), + }, + Self::Bottom => { + *self = vh.into(); + } + }; + true + } + + fn meet_mut_value_handle(&mut self, vh: ValueHandle) -> bool { + self.assert_invariants(); + match &*self { + Self::Bottom => false, + Self::Value(v) => { + if v == &vh { + false + } else { + *self = Self::Bottom; + true + } + } + Self::PartialSum(_) => match vh.into() { + Self::Value(_) => { + *self = Self::Bottom; + true + } + other => self.meet_mut(other), + }, + Self::Top => { + *self = vh.into(); + true + } + } + } + + pub fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + pub fn join_mut(&mut self, other: Self) -> bool { + // println!("join {self:?}\n{:?}", &other); + let changed = match (&*self, other) { + (Self::Top, _) => false, + (_, other @ Self::Top) => { + *self = other; + true + } + (_, Self::Bottom) => false, + (Self::Bottom, other) => { + *self = other; + true + } + (Self::Value(h1), Self::Value(h2)) => { + if h1 == &h2 { + false + } else { + *self = Self::Top; + true + } + } + (Self::PartialSum(_), Self::PartialSum(ps2)) => { + let Self::PartialSum(ps1) = self else { + unreachable!() + }; + ps1.join_mut_unsafe(ps2) + } + (Self::Value(_), mut other) => { + std::mem::swap(self, &mut other); + let Self::Value(old_self) = other else { + unreachable!() + }; + self.join_mut_value_handle(old_self) + } + (_, Self::Value(h)) => self.join_mut_value_handle(h), + // (new_self, _) => { + // **new_self = Self::Top; + // false + // } + }; + // if changed { + // println!("join new self: {:?}", s); + // } + changed + } + + pub fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + pub fn meet_mut(&mut self, other: Self) -> bool { + let changed = match (&*self, other) { + (Self::Bottom, _) => false, + (_, other @ Self::Bottom) => { + *self = other; + true + } + (_, Self::Top) => false, + (Self::Top, other) => { + *self = other; + true + } + (Self::Value(h1), Self::Value(h2)) => { + if h1 == &h2 { + false + } else { + *self = Self::Bottom; + true + } + } + (Self::PartialSum(_), Self::PartialSum(ps2)) => { + let Self::PartialSum(ps1) = self else { + unreachable!() + }; + ps1.meet_mut_unsafe(ps2) + } + (Self::Value(_), mut other @ Self::PartialSum(_)) => { + std::mem::swap(self, &mut other); + let Self::Value(old_self) = other else { + unreachable!() + }; + self.meet_mut_value_handle(old_self) + } + (Self::PartialSum(_), Self::Value(h)) => self.meet_mut_value_handle(h), + // (new_self, _) => { + // **new_self = Self::Bottom; + // false + // } + }; + // if changed { + // println!("join new self: {:?}", s); + // } + changed + } + + pub fn top() -> Self { + Self::Top + } + + pub fn bottom() -> Self { + Self::Bottom + } + + pub fn variant(tag: usize, values: impl IntoIterator) -> Self { + PartialSum::variant(tag, values).into() + } + + pub fn unit() -> Self { + Self::variant(0, []) + } + + pub fn supports_tag(&self, tag: usize) -> bool { + match self { + PartialValue::Bottom => false, + PartialValue::Value(v) => v.tag() == tag, // can never be a sum or tuple + PartialValue::PartialSum(ps) => ps.supports_tag(tag), + PartialValue::Top => true, + } + } + + /// TODO docs + /// just delegate to variant_field_value + pub fn tuple_field_value(&self, idx: usize) -> Self { + self.variant_field_value(0, idx) + } + + /// TODO docs + pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { + match self { + Self::Bottom => Self::Bottom, + Self::PartialSum(ps) => ps.variant_field_value(variant, idx), + Self::Value(v) => { + if v.tag() == variant { + Self::Value(v.index(idx)) + } else { + Self::Bottom + } + } + Self::Top => Self::Top, + } + } +} + +impl PartialOrd for PartialValue { + fn partial_cmp(&self, other: &Self) -> Option { + use std::cmp::Ordering; + match (self, other) { + (Self::Bottom, Self::Bottom) => Some(Ordering::Equal), + (Self::Top, Self::Top) => Some(Ordering::Equal), + (Self::Bottom, _) => Some(Ordering::Less), + (_, Self::Bottom) => Some(Ordering::Greater), + (Self::Top, _) => Some(Ordering::Greater), + (_, Self::Top) => Some(Ordering::Less), + (Self::Value(v1), Self::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), + (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), + _ => None, + } + } +} + +#[cfg(test)] +mod test; diff --git a/hugr-passes/src/const_fold2/partial_value/test.rs b/hugr-passes/src/const_fold2/partial_value/test.rs new file mode 100644 index 000000000..35fbf5373 --- /dev/null +++ b/hugr-passes/src/const_fold2/partial_value/test.rs @@ -0,0 +1,346 @@ +use std::sync::Arc; + +use itertools::{zip_eq, Either, Itertools as _}; +use lazy_static::lazy_static; +use proptest::prelude::*; + +use crate::{ + ops::Value, + std_extensions::arithmetic::int_types::{ + self, get_log_width, ConstInt, INT_TYPES, LOG_WIDTH_BOUND, + }, + types::{CustomType, Type, TypeEnum}, +}; + +use super::{PartialSum, PartialValue, ValueHandle, ValueKey}; +impl Arbitrary for ValueHandle { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { + // prop_oneof![ + + // ] + todo!() + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +enum TestSumLeafType { + Int(Type), + Unit, +} + +impl TestSumLeafType { + fn assert_invariants(&self) { + match self { + Self::Int(t) => { + if let TypeEnum::Extension(ct) = t.as_type_enum() { + assert_eq!("int", ct.name()); + assert_eq!(&int_types::EXTENSION_ID, ct.extension()); + } else { + panic!("Expected int type, got {:#?}", t); + } + } + _ => (), + } + } + + fn get_type(&self) -> Type { + match self { + Self::Int(t) => t.clone(), + Self::Unit => Type::UNIT, + } + } + + fn type_check(&self, ps: &PartialSum) -> bool { + match self { + Self::Int(_) => false, + Self::Unit => { + if let Ok((0, v)) = ps.0.iter().exactly_one() { + v.is_empty() + } else { + false + } + } + } + } + + fn partial_value_strategy(self) -> impl Strategy { + match self { + Self::Int(t) => { + let TypeEnum::Extension(ct) = t.as_type_enum() else { + unreachable!() + }; + let lw = get_log_width(&ct.args()[0]).unwrap(); + (0u64..(1 << (2u64.pow(lw as u32) - 1))) + .prop_map(move |x| { + let ki = ConstInt::new_u(lw, x).unwrap(); + ValueHandle::new(ValueKey::new(ki.clone()), Arc::new(ki.into())).into() + }) + .boxed() + } + Self::Unit => Just(PartialSum::unit().into()).boxed(), + } + } +} + +impl Arbitrary for TestSumLeafType { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { + let int_strat = (0..LOG_WIDTH_BOUND).prop_map(|i| Self::Int(INT_TYPES[i as usize].clone())); + prop_oneof![Just(TestSumLeafType::Unit), int_strat].boxed() + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +enum TestSumType { + Branch(usize, Vec>>), + Leaf(TestSumLeafType), +} + +impl TestSumType { + const UNIT: TestSumLeafType = TestSumLeafType::Unit; + + fn leaf(v: Type) -> Self { + TestSumType::Leaf(TestSumLeafType::Int(v)) + } + + fn branch(vs: impl IntoIterator>>) -> Self { + let vec = vs.into_iter().collect_vec(); + let depth: usize = vec + .iter() + .flat_map(|x| x.iter()) + .map(|x| x.depth() + 1) + .max() + .unwrap_or(0); + Self::Branch(depth, vec) + } + + fn depth(&self) -> usize { + match self { + TestSumType::Branch(x, _) => *x, + TestSumType::Leaf(_) => 0, + } + } + + fn is_leaf(&self) -> bool { + self.depth() == 0 + } + + fn assert_invariants(&self) { + match self { + TestSumType::Branch(d, sop) => { + assert!(!sop.is_empty(), "No variants"); + for v in sop.iter().flat_map(|x| x.iter()) { + assert!(v.depth() < *d); + v.assert_invariants(); + } + } + TestSumType::Leaf(l) => { + l.assert_invariants(); + } + _ => (), + } + } + + fn select(self) -> impl Strategy>)>> { + match self { + TestSumType::Branch(_, sop) => any::() + .prop_map(move |i| { + let index = i.index(sop.len()); + Either::Right((index, sop[index].clone())) + }) + .boxed(), + TestSumType::Leaf(l) => Just(Either::Left(l)).boxed(), + } + } + + fn get_type(&self) -> Type { + match self { + TestSumType::Branch(_, sop) => Type::new_sum( + sop.iter() + .map(|row| row.iter().map(|x| x.get_type()).collect_vec().into()), + ), + TestSumType::Leaf(l) => l.get_type(), + } + } + + fn type_check(&self, pv: &PartialValue) -> bool { + match (self, pv) { + (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, + (_, PartialValue::Value(v)) => self.get_type() == v.get_type(), + (TestSumType::Branch(_, sop), PartialValue::PartialSum(ps)) => { + for (k, v) in &ps.0 { + if *k >= sop.len() { + return false; + } + let prod = &sop[*k]; + if prod.len() != v.len() { + return false; + } + if !zip_eq(prod, v).all(|(lhs, rhs)| lhs.type_check(rhs)) { + return false; + } + } + true + } + (Self::Leaf(l), PartialValue::PartialSum(ps)) => l.type_check(ps), + } + } +} + +impl From for TestSumType { + fn from(value: TestSumLeafType) -> Self { + Self::Leaf(value) + } +} + +#[derive(Clone, PartialEq, Eq, Debug)] +struct UnarySumTypeParams { + depth: usize, + branch_width: usize, +} + +impl UnarySumTypeParams { + pub fn descend(mut self, d: usize) -> Self { + assert!(d < self.depth); + self.depth = d; + self + } +} + +impl Default for UnarySumTypeParams { + fn default() -> Self { + Self { + depth: 3, + branch_width: 3, + } + } +} + +impl Arbitrary for TestSumType { + type Parameters = UnarySumTypeParams; + type Strategy = BoxedStrategy; + fn arbitrary_with( + params @ UnarySumTypeParams { + depth, + branch_width, + }: Self::Parameters, + ) -> Self::Strategy { + if depth == 0 { + any::().prop_map_into().boxed() + } else { + (0..depth) + .prop_flat_map(move |d| { + prop::collection::vec( + prop::collection::vec( + any_with::(params.clone().descend(d)).prop_map_into(), + 0..branch_width, + ), + 1..=branch_width, + ) + .prop_map(TestSumType::branch) + }) + .boxed() + } + } +} + +proptest! { + #[test] + fn unary_sum_type_valid(ust: TestSumType) { + ust.assert_invariants(); + } +} + +fn any_partial_value_of_type(ust: TestSumType) -> impl Strategy { + ust.select().prop_flat_map(|x| match x { + Either::Left(l) => l.partial_value_strategy().boxed(), + Either::Right((index, usts)) => { + let pvs = usts + .into_iter() + .map(|x| { + any_partial_value_of_type( + Arc::::try_unwrap(x).unwrap_or_else(|x| x.as_ref().clone()), + ) + }) + .collect_vec(); + pvs.prop_map(move |pvs| PartialValue::variant(index, pvs)) + .boxed() + } + }) +} + +fn any_partial_value_with( + params: ::Parameters, +) -> impl Strategy { + any_with::(params).prop_flat_map(any_partial_value_of_type) +} + +fn any_partial_value() -> impl Strategy { + any_partial_value_with(Default::default()) +} + +fn any_partial_values() -> impl Strategy { + any::().prop_flat_map(|ust| { + TryInto::<[_; N]>::try_into( + (0..N) + .map(|_| any_partial_value_of_type(ust.clone())) + .collect_vec(), + ) + .unwrap() + }) +} + +fn any_typed_partial_value() -> impl Strategy { + any::() + .prop_flat_map(|t| any_partial_value_of_type(t.clone()).prop_map(move |v| (t.clone(), v))) +} + +proptest! { + #[test] + fn partial_value_type((tst, pv) in any_typed_partial_value()) { + prop_assert!(tst.type_check(&pv)) + } + + // todo: ValidHandle is valid + // todo: ValidHandle eq is an equivalence relation + + // todo: PartialValue PartialOrd is transitive + // todo: PartialValue eq is an equivalence relation + #[test] + fn partial_value_valid(pv in any_partial_value()) { + pv.assert_invariants(); + } + + #[test] + fn bounded_lattice(v in any_partial_value()) { + prop_assert!(v <= PartialValue::top()); + prop_assert!(v >= PartialValue::bottom()); + } + + #[test] + fn meet_join_self_noop(v1 in any_partial_value()) { + let mut subject = v1.clone(); + + assert_eq!(v1.clone(), v1.clone().join(v1.clone())); + assert!(!subject.join_mut(v1.clone())); + assert_eq!(subject, v1); + + assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); + assert!(!subject.meet_mut(v1.clone())); + assert_eq!(subject, v1); + } + + #[test] + fn lattice([v1,v2] in any_partial_values()) { + let meet = v1.clone().meet(v2.clone()); + prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); + prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); + + let join = v1.clone().join(v2.clone()); + prop_assert!(join >= v1, "join not >=: {:#?}", &join); + prop_assert!(join >= v2, "join not >=: {:#?}", &join); + } +} diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs new file mode 100644 index 000000000..dfb019872 --- /dev/null +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -0,0 +1,245 @@ +use std::any::Any; +use std::hash::{DefaultHasher, Hash, Hasher}; +use std::ops::Deref; +use std::sync::Arc; + +use downcast_rs::Downcast; +use itertools::Either; + +use crate::ops::Value; +use crate::std_extensions::arithmetic::int_types::ConstInt; +use crate::Node; + +pub trait ValueName: std::fmt::Debug + Downcast + Any { + fn hash(&self) -> u64; + fn eq(&self, other: &dyn ValueName) -> bool; +} + +fn hash_hash(x: &impl Hash) -> u64 { + let mut hasher = DefaultHasher::new(); + x.hash(&mut hasher); + hasher.finish() +} + +fn value_name_eq(x: &T, other: &dyn ValueName) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + x == other + } else { + false + } +} + +impl ValueName for String { + fn hash(&self) -> u64 { + hash_hash(self) + } + + fn eq(&self, other: &dyn ValueName) -> bool { + value_name_eq(self, other) + } +} + +impl ValueName for ConstInt { + fn hash(&self) -> u64 { + hash_hash(self) + } + + fn eq(&self, other: &dyn ValueName) -> bool { + value_name_eq(self, other) + } +} + +#[derive(Clone, Debug)] +pub struct ValueKey(Vec, Either>); + +impl PartialEq for ValueKey { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + && match (&self.1, &other.1) { + (Either::Left(ref n1), Either::Left(ref n2)) => n1 == n2, + (Either::Right(ref v1), Either::Right(ref v2)) => v1.eq(v2.as_ref()), + _ => false, + } + } +} + +impl Eq for ValueKey {} + +impl Hash for ValueKey { + fn hash(&self, state: &mut H) { + self.0.hash(state); + match &self.1 { + Either::Left(n) => (0, n).hash(state), + Either::Right(v) => (1, v.hash()).hash(state), + } + } +} + +impl From for ValueKey { + fn from(n: Node) -> Self { + Self(vec![], Either::Left(n)) + } +} + +impl ValueKey { + pub fn new(k: impl ValueName) -> Self { + Self(vec![], Either::Right(Arc::new(k))) + } + + pub fn index(self, i: usize) -> Self { + let mut is = self.0; + is.push(i); + Self(is, self.1) + } +} + +#[derive(Clone, Debug)] +pub struct ValueHandle(ValueKey, Arc); + +impl ValueHandle { + pub fn new(key: ValueKey, value: Arc) -> Self { + Self(key, value) + } + + pub fn value(&self) -> &Value { + self.1.as_ref() + } + + pub fn is_compound(&self) -> bool { + match self.value() { + Value::Sum { .. } | Value::Tuple { .. } => true, + _ => false, + } + } + + pub fn num_fields(&self) -> usize { + assert!( + self.is_compound(), + "ValueHandle::num_fields called on non-Sum, non-Tuple value: {:#?}", + self + ); + match self.value() { + Value::Sum { values, .. } => values.len(), + Value::Tuple { vs } => vs.len(), + _ => unreachable!(), + } + } + + pub fn tag(&self) -> usize { + assert!( + self.is_compound(), + "ValueHandle::tag called on non-Sum, non-Tuple value: {:#?}", + self + ); + match self.value() { + Value::Sum { tag, .. } => *tag, + Value::Tuple { .. } => 0, + _ => unreachable!(), + } + } + + pub fn index(self: &ValueHandle, i: usize) -> ValueHandle { + assert!( + i < self.num_fields(), + "ValueHandle::index called with out-of-bounds index {}: {:#?}", + i, + &self + ); + let vs = match self.value() { + Value::Sum { values, .. } => values, + Value::Tuple { vs, .. } => vs, + _ => unreachable!(), + }; + let v = vs[i].clone().into(); + Self(self.0.clone().index(i), v) + } +} + +impl PartialEq for ValueHandle { + fn eq(&self, other: &Self) -> bool { + // If the keys are equal, we return true since the values must have the + // same provenance, and so be equal. If the keys are different but the + // values are equal, we could return true if we didn't impl Eq, but + // since we do impl Eq, the Hash contract prohibits us from having equal + // values with different hashes. + let r = self.0 == other.0; + if r { + debug_assert_eq!(self.get_type(), other.get_type()); + } + r + } +} + +impl Eq for ValueHandle {} + +impl Hash for ValueHandle { + fn hash(&self, state: &mut I) { + self.0.hash(state); + } +} + +/// TODO this is perhaps dodgy +/// we do not hash or compare the value, just the key +/// this means two handles with different keys, but with the same value, will +/// not compare equal. +impl Deref for ValueHandle { + type Target = Value; + + fn deref(&self) -> &Self::Target { + self.value() + } +} + +#[cfg(test)] +mod test { + use crate::{ops::constant::CustomConst as _, types::SumType}; + + use super::*; + + #[test] + fn value_key_eq() { + let k1 = ValueKey::new("foo".to_string()); + let k2 = ValueKey::new("foo".to_string()); + let k3 = ValueKey::new("bar".to_string()); + + assert_eq!(k1, k2); + assert_ne!(k1, k3); + + let k4: ValueKey = From::::from(portgraph::NodeIndex::new(1).into()); + let k5 = From::::from(portgraph::NodeIndex::new(1).into()); + let k6 = From::::from(portgraph::NodeIndex::new(2).into()); + + assert_eq!(&k4, &k5); + assert_ne!(&k4, &k6); + + let k7 = k5.clone().index(3); + let k4 = k4.index(3); + + assert_eq!(&k4, &k7); + + let k5 = k5.index(2); + + assert_ne!(&k5, &k7); + } + + #[test] + fn value_handle_eq() { + let k_i = ConstInt::new_u(4, 2).unwrap(); + let subject_val = Arc::new( + Value::sum( + 0, + [k_i.clone().into()], + SumType::new([vec![k_i.get_type()], vec![]]), + ) + .unwrap(), + ); + + let k1 = ValueKey::new("foo".to_string()); + let v1 = ValueHandle::new(k1.clone(), subject_val.clone()); + let v2 = ValueHandle::new(k1.clone(), Value::extension(k_i).into()); + + // we do not compare the value, just the key + assert_ne!(v1.index(0), v2); + assert_eq!(v1.index(0).value(), v2.value()); + } +} diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 13dd47776..8949d8bd4 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,6 +1,7 @@ //! Compilation passes acting on the HUGR program representation. pub mod const_fold; +pub mod const_fold2; pub mod force_order; mod half_node; pub mod lower; From ac45e53ed2346eec5604ceb79f15ca4cff180a2f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 11 Sep 2024 16:16:29 +0100 Subject: [PATCH 002/281] merge/update+fmt (ValueName for ConstInt non-compiling as ConstInt not Hash) --- hugr-passes/Cargo.toml | 5 ++++ hugr-passes/src/const_fold2.rs | 2 +- hugr-passes/src/const_fold2/datalog.rs | 6 ++-- .../src/const_fold2/datalog/context.rs | 1 - hugr-passes/src/const_fold2/datalog/test.rs | 27 ++++++++--------- hugr-passes/src/const_fold2/datalog/utils.rs | 6 ++-- hugr-passes/src/const_fold2/partial_value.rs | 19 ++++++------ .../src/const_fold2/partial_value/test.rs | 17 ++++++----- .../const_fold2/partial_value/value_handle.rs | 29 ++++++++----------- 9 files changed, 54 insertions(+), 58 deletions(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index f0b09516d..a6ed580c3 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -14,6 +14,9 @@ categories = ["compilers"] [dependencies] hugr-core = { path = "../hugr-core", version = "0.9.1" } +portgraph = { workspace = true } +ascent = "0.6.0" +downcast-rs = { workspace = true } itertools = { workspace = true } lazy_static = { workspace = true } paste = { workspace = true } @@ -25,3 +28,5 @@ extension_inference = ["hugr-core/extension_inference"] [dev-dependencies] rstest = { workspace = true } +proptest = { workspace = true } +proptest-derive = { workspace = true } diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index dbe4464fd..96af004e1 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -1,2 +1,2 @@ mod datalog; -pub mod partial_value; \ No newline at end of file +pub mod partial_value; diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index d7df9c1e6..0aca8e9b8 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -1,14 +1,12 @@ use ascent::lattice::{ord_lattice::OrdLattice, BoundedLattice, Dual, Lattice}; -use delegate::delegate; use itertools::{zip_eq, Itertools}; use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::sync::{Arc, Mutex}; -use either::Either; +use super::partial_value::{PartialValue, ValueHandle, ValueKey}; use hugr_core::ops::{OpTag, OpTrait, Value}; -use hugr_core::partial_value::{PartialValue, ValueHandle, ValueKey}; -use hugr_core::types::{EdgeKind, FunctionType, SumType, Type, TypeEnum, TypeRow}; +use hugr_core::types::{EdgeKind, SumType, Type, TypeEnum, TypeRow}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; mod context; diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index 92c0c3285..9117cc429 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -6,7 +6,6 @@ use std::sync::{Arc, Mutex}; use hugr_core::hugr::internal::HugrInternals; use hugr_core::ops::Value; -use hugr_core::partial_value::{ValueHandle, ValueKey}; use hugr_core::{Hugr, HugrView, Node}; use super::DFContext; diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index 4e086c4b7..5e70bf8b4 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -3,17 +3,17 @@ use hugr_core::{ extension::{prelude::BOOL_T, ExtensionSet, EMPTY_REG}, ops::{handle::NodeHandle, OpTrait, UnpackTuple, Value}, type_row, - types::{FunctionType, SumType}, + types::{Signature, SumType}, Extension, }; -use hugr_core::partial_value::PartialValue; +use crate::const_fold2::partial_value::PartialValue; use super::*; #[test] fn test_make_tuple() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); let v1 = builder.add_load_value(Value::false_val()); let v2 = builder.add_load_value(Value::true_val()); let v3 = builder.make_tuple([v1, v2]).unwrap(); @@ -28,7 +28,7 @@ fn test_make_tuple() { #[test] fn test_unpack_tuple() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); let v1 = builder.add_load_value(Value::false_val()); let v2 = builder.add_load_value(Value::true_val()); let v3 = builder.make_tuple([v1, v2]).unwrap(); @@ -49,7 +49,7 @@ fn test_unpack_tuple() { #[test] fn test_unpack_const() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); let v1 = builder.add_load_value(Value::tuple([Value::true_val()])); let [o] = builder .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T]), [v1]) @@ -66,7 +66,7 @@ fn test_unpack_const() { #[test] fn test_tail_loop_never_iterates() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); let r_v = Value::unit_sum(3, 6).unwrap(); let r_w = builder.add_load_value( Value::sum( @@ -98,7 +98,7 @@ fn test_tail_loop_never_iterates() { #[test] fn test_tail_loop_always_iterates() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); let r_w = builder .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); let true_w = builder.add_load_value(Value::true_val()); @@ -130,7 +130,7 @@ fn test_tail_loop_always_iterates() { #[test] fn test_tail_loop_iterates_twice() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); // let var_type = Type::new_sum([type_row![BOOL_T,BOOL_T], type_row![BOOL_T,BOOL_T]]); let true_w = builder.add_load_value(Value::true_val()); @@ -143,7 +143,7 @@ fn test_tail_loop_iterates_twice() { .unwrap(); assert_eq!( tlb.loop_signature().unwrap().dataflow_signature().unwrap(), - FunctionType::new_endo(type_row![BOOL_T, BOOL_T]) + Signature::new_endo(type_row![BOOL_T, BOOL_T]) ); let [in_w1, in_w2] = tlb.input_wires_arr(); let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap(); @@ -180,18 +180,15 @@ fn test_tail_loop_iterates_twice() { fn conditional() { let variants = vec![type_row![], type_row![], type_row![BOOL_T]]; let cond_t = Type::new_sum(variants.clone()); - let mut builder = DFGBuilder::new(FunctionType::new( - Into::::into(cond_t), - type_row![], - )) - .unwrap(); + let mut builder = + DFGBuilder::new(Signature::new(Into::::into(cond_t), type_row![])).unwrap(); let [arg_w] = builder.input_wires_arr(); let true_w = builder.add_load_value(Value::true_val()); let false_w = builder.add_load_value(Value::false_val()); let mut cond_builder = builder - .conditional_builder( + .conditional_builder_exts( (variants, arg_w), [(BOOL_T, true_w)], type_row!(BOOL_T, BOOL_T), diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 9c2e46ae3..31162a718 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -6,14 +6,14 @@ use std::{cmp::Ordering, ops::Index, sync::Arc}; use ascent::lattice::{BoundedLattice, Lattice}; -use either::Either; +use itertools::{zip_eq, Either}; + +use crate::const_fold2::partial_value::{PartialValue, ValueHandle}; use hugr_core::{ ops::OpTrait as _, - partial_value::{PartialValue, ValueHandle}, types::{EdgeKind, TypeRow}, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, }; -use itertools::zip_eq; #[cfg(test)] use proptest_derive::Arbitrary; diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index 0442aa4c9..dafc48fce 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -3,10 +3,11 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; +use hugr_core::ops::constant::Sum; use itertools::{zip_eq, Itertools as _}; -use crate::ops::Value; -use crate::types::{Type, TypeEnum}; +use hugr_core::ops::Value; +use hugr_core::types::{Type, TypeEnum, TypeRow}; mod value_handle; @@ -17,6 +18,9 @@ pub use value_handle::{ValueHandle, ValueKey}; struct PartialSum(HashMap>); impl PartialSum { + pub fn unit() -> Self { + Self::variant(0, []) + } pub fn variant(tag: usize, values: impl IntoIterator) -> Self { Self([(tag, values.into_iter().collect())].into_iter().collect()) } @@ -52,6 +56,9 @@ impl PartialSum { let Some(r) = st.get_variant(*k) else { Err(self)? }; + let Ok(r): Result = r.clone().try_into() else { + Err(self)? + }; if v.len() != r.len() { return Err(self); } @@ -165,13 +172,7 @@ impl TryFrom for PartialSum { fn try_from(value: ValueHandle) -> Result { match value.value() { - Value::Tuple { vs } => { - let vec = (0..vs.len()) - .map(|i| PartialValue::from(value.index(i)).into()) - .collect(); - return Ok(Self([(0, vec)].into_iter().collect())); - } - Value::Sum { tag, values, .. } => { + Value::Sum(Sum { tag, values, .. }) => { let vec = (0..values.len()) .map(|i| PartialValue::from(value.index(i)).into()) .collect(); diff --git a/hugr-passes/src/const_fold2/partial_value/test.rs b/hugr-passes/src/const_fold2/partial_value/test.rs index 35fbf5373..227d7aff7 100644 --- a/hugr-passes/src/const_fold2/partial_value/test.rs +++ b/hugr-passes/src/const_fold2/partial_value/test.rs @@ -4,12 +4,10 @@ use itertools::{zip_eq, Either, Itertools as _}; use lazy_static::lazy_static; use proptest::prelude::*; -use crate::{ +use hugr_core::{ ops::Value, - std_extensions::arithmetic::int_types::{ - self, get_log_width, ConstInt, INT_TYPES, LOG_WIDTH_BOUND, - }, - types::{CustomType, Type, TypeEnum}, + std_extensions::arithmetic::int_types::{self, ConstInt, INT_TYPES, LOG_WIDTH_BOUND}, + types::{CustomType, Type, TypeArg, TypeEnum}, }; use super::{PartialSum, PartialValue, ValueHandle, ValueKey}; @@ -71,10 +69,13 @@ impl TestSumLeafType { let TypeEnum::Extension(ct) = t.as_type_enum() else { unreachable!() }; - let lw = get_log_width(&ct.args()[0]).unwrap(); + // TODO this should be get_log_width, but that's not pub + let TypeArg::BoundedNat { n: lw } = ct.args()[0] else { + panic!() + }; (0u64..(1 << (2u64.pow(lw as u32) - 1))) .prop_map(move |x| { - let ki = ConstInt::new_u(lw, x).unwrap(); + let ki = ConstInt::new_u(lw as u8, x).unwrap(); ValueHandle::new(ValueKey::new(ki.clone()), Arc::new(ki.into())).into() }) .boxed() @@ -160,7 +161,7 @@ impl TestSumType { match self { TestSumType::Branch(_, sop) => Type::new_sum( sop.iter() - .map(|row| row.iter().map(|x| x.get_type()).collect_vec().into()), + .map(|row| row.iter().map(|x| x.get_type()).collect_vec()), ), TestSumType::Leaf(l) => l.get_type(), } diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index dfb019872..6a91d513a 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -4,11 +4,12 @@ use std::ops::Deref; use std::sync::Arc; use downcast_rs::Downcast; +use hugr_core::ops::constant::Sum; use itertools::Either; -use crate::ops::Value; -use crate::std_extensions::arithmetic::int_types::ConstInt; -use crate::Node; +use hugr_core::ops::Value; +use hugr_core::std_extensions::arithmetic::int_types::ConstInt; +use hugr_core::Node; pub trait ValueName: std::fmt::Debug + Downcast + Any { fn hash(&self) -> u64; @@ -106,10 +107,7 @@ impl ValueHandle { } pub fn is_compound(&self) -> bool { - match self.value() { - Value::Sum { .. } | Value::Tuple { .. } => true, - _ => false, - } + matches!(self.value(), Value::Sum(_)) } pub fn num_fields(&self) -> usize { @@ -119,8 +117,7 @@ impl ValueHandle { self ); match self.value() { - Value::Sum { values, .. } => values.len(), - Value::Tuple { vs } => vs.len(), + Value::Sum(Sum { values, .. }) => values.len(), _ => unreachable!(), } } @@ -132,8 +129,7 @@ impl ValueHandle { self ); match self.value() { - Value::Sum { tag, .. } => *tag, - Value::Tuple { .. } => 0, + Value::Sum(Sum { tag, .. }) => *tag, _ => unreachable!(), } } @@ -146,8 +142,7 @@ impl ValueHandle { &self ); let vs = match self.value() { - Value::Sum { values, .. } => values, - Value::Tuple { vs, .. } => vs, + Value::Sum(Sum { values, .. }) => values, _ => unreachable!(), }; let v = vs[i].clone().into(); @@ -192,7 +187,7 @@ impl Deref for ValueHandle { #[cfg(test)] mod test { - use crate::{ops::constant::CustomConst as _, types::SumType}; + use hugr_core::{ops::constant::CustomConst as _, types::SumType}; use super::*; @@ -205,9 +200,9 @@ mod test { assert_eq!(k1, k2); assert_ne!(k1, k3); - let k4: ValueKey = From::::from(portgraph::NodeIndex::new(1).into()); - let k5 = From::::from(portgraph::NodeIndex::new(1).into()); - let k6 = From::::from(portgraph::NodeIndex::new(2).into()); + let k4: ValueKey = Node::from(portgraph::NodeIndex::new(1)).into(); + let k5: ValueKey = Node::from(portgraph::NodeIndex::new(1)).into(); + let k6: ValueKey = Node::from(portgraph::NodeIndex::new(2)).into(); assert_eq!(&k4, &k5); assert_ne!(&k4, &k6); From 8adaa6e5ae809c958b65cedf0775efcdb1e15c66 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 14:58:55 +0100 Subject: [PATCH 003/281] Missing imports / lints. Now running, but failing w/StackOverflow --- hugr-passes/src/const_fold2/datalog.rs | 11 ++++------- hugr-passes/src/const_fold2/datalog/context.rs | 7 ++----- hugr-passes/src/const_fold2/datalog/test.rs | 7 +++---- hugr-passes/src/const_fold2/datalog/utils.rs | 4 +--- hugr-passes/src/const_fold2/partial_value/test.rs | 5 +---- 5 files changed, 11 insertions(+), 23 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 0aca8e9b8..7e30b29e6 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -1,12 +1,9 @@ -use ascent::lattice::{ord_lattice::OrdLattice, BoundedLattice, Dual, Lattice}; -use itertools::{zip_eq, Itertools}; +use ascent::lattice::BoundedLattice; use std::collections::HashMap; -use std::hash::{Hash, Hasher}; -use std::sync::{Arc, Mutex}; +use std::hash::Hash; -use super::partial_value::{PartialValue, ValueHandle, ValueKey}; -use hugr_core::ops::{OpTag, OpTrait, Value}; -use hugr_core::types::{EdgeKind, SumType, Type, TypeEnum, TypeRow}; +use super::partial_value::PartialValue; +use hugr_core::ops::Value; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; mod context; diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index 9117cc429..81e3709c4 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -1,12 +1,9 @@ -use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::ops::Deref; -use std::sync::atomic::AtomicUsize; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use hugr_core::hugr::internal::HugrInternals; -use hugr_core::ops::Value; -use hugr_core::{Hugr, HugrView, Node}; +use hugr_core::{Hugr, HugrView}; use super::DFContext; diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index 5e70bf8b4..bea8db857 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -3,8 +3,7 @@ use hugr_core::{ extension::{prelude::BOOL_T, ExtensionSet, EMPTY_REG}, ops::{handle::NodeHandle, OpTrait, UnpackTuple, Value}, type_row, - types::{Signature, SumType}, - Extension, + types::{Signature, SumType, Type, TypeRow}, }; use crate::const_fold2::partial_value::PartialValue; @@ -58,7 +57,7 @@ fn test_unpack_const() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - let c = machine.run_hugr(&hugr); + machine.run_hugr(&hugr); let o_r = machine.read_out_wire_value(&hugr, o).unwrap(); assert_eq!(o_r, Value::true_val()); @@ -161,7 +160,7 @@ fn test_tail_loop_iterates_twice() { let [o_w1, o_w2, _] = tail_loop.outputs_arr(); let mut machine = Machine::new(); - let c = machine.run_hugr(&hugr); + machine.run_hugr(&hugr); // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 31162a718..5c2b12730 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -10,9 +10,7 @@ use itertools::{zip_eq, Either}; use crate::const_fold2::partial_value::{PartialValue, ValueHandle}; use hugr_core::{ - ops::OpTrait as _, - types::{EdgeKind, TypeRow}, - HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, + ops::OpTrait as _, types::TypeRow, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, }; #[cfg(test)] diff --git a/hugr-passes/src/const_fold2/partial_value/test.rs b/hugr-passes/src/const_fold2/partial_value/test.rs index 227d7aff7..5e3b861e3 100644 --- a/hugr-passes/src/const_fold2/partial_value/test.rs +++ b/hugr-passes/src/const_fold2/partial_value/test.rs @@ -1,13 +1,11 @@ use std::sync::Arc; use itertools::{zip_eq, Either, Itertools as _}; -use lazy_static::lazy_static; use proptest::prelude::*; use hugr_core::{ - ops::Value, std_extensions::arithmetic::int_types::{self, ConstInt, INT_TYPES, LOG_WIDTH_BOUND}, - types::{CustomType, Type, TypeArg, TypeEnum}, + types::{Type, TypeArg, TypeEnum}, }; use super::{PartialSum, PartialValue, ValueHandle, ValueKey}; @@ -141,7 +139,6 @@ impl TestSumType { TestSumType::Leaf(l) => { l.assert_invariants(); } - _ => (), } } From 098c7350c58fcf9f74b794db5cec2ffdce62f60e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 17:43:51 +0100 Subject: [PATCH 004/281] Fix tests... * DFContext reinstate fn hugr(), drop AsRef requirement (fixes StackOverflow) * test_tail_loop_iterates_twice: use tail_loop_builder_exts, fix from #1332(?) * Fix only-one-DataflowContext asserts using Arc::ptr_eq --- hugr-passes/src/const_fold2/datalog.rs | 18 ++++++++++-------- .../src/const_fold2/datalog/context.rs | 19 ++++++++----------- hugr-passes/src/const_fold2/datalog/test.rs | 9 +++++++-- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 7e30b29e6..96c4dd50c 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -12,7 +12,9 @@ mod utils; use context::DataflowContext; pub use utils::{TailLoopTermination, ValueRow, IO, PV}; -pub trait DFContext: AsRef + Clone + Eq + Hash + std::ops::Deref {} +pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { + fn hugr(&self) -> &impl HugrView; +} ascent::ascent! { // The trait-indirection layer here means we can just write 'C' but in practice ATM @@ -34,9 +36,9 @@ ascent::ascent! { node(c, n) <-- context(c), for n in c.nodes(); - in_wire(c, n,p) <-- node(c, n), for p in utils::value_inputs(c, *n); + in_wire(c, n,p) <-- node(c, n), for p in utils::value_inputs(c.hugr(), *n); - out_wire(c, n,p) <-- node(c, n), for p in utils::value_outputs(c, *n); + out_wire(c, n,p) <-- node(c, n), for p in utils::value_outputs(c.hugr(), *n); parent_of_node(c, parent, child) <-- node(c, child), if let Some(parent) = c.get_parent(*child); @@ -55,8 +57,8 @@ ascent::ascent! { out_wire_value(c, m, op, v); - node_in_value_row(c, n, utils::bottom_row(c, *n)) <-- node(c, n); - node_in_value_row(c, n, utils::singleton_in_row(c, n, p, v.clone())) <-- in_wire_value(c, n, p, v); + node_in_value_row(c, n, utils::bottom_row(c.hugr(), *n)) <-- node(c, n); + node_in_value_row(c, n, utils::singleton_in_row(c.hugr(), n, p, v.clone())) <-- in_wire_value(c, n, p, v); // Per node-type rules @@ -67,7 +69,7 @@ ascent::ascent! { relation load_constant_node(C, Node); load_constant_node(c, n) <-- node(c, n), if c.get_optype(*n).is_load_constant(); - out_wire_value(c, n, 0.into(), utils::partial_value_from_load_constant(c, *n)) <-- + out_wire_value(c, n, 0.into(), utils::partial_value_from_load_constant(c.hugr(), *n)) <-- load_constant_node(c, n); @@ -116,7 +118,7 @@ ascent::ascent! { if out_in_row[0].supports_tag(0), // if it is possible for tag to be 0 if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), let variant_len = tailloop.just_inputs.len(), - for (out_p, v) in out_in_row.iter(c, *out_n).flat_map( + for (out_p, v) in out_in_row.iter(c.hugr(), *out_n).flat_map( |(input_p, v)| utils::outputs_for_variant(input_p, 0, variant_len, v) ); @@ -127,7 +129,7 @@ ascent::ascent! { if out_in_row[0].supports_tag(1), // if it is possible for the tag to be 1 if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), let variant_len = tailloop.just_outputs.len(), - for (out_p, v) in out_in_row.iter(c, *out_n).flat_map( + for (out_p, v) in out_in_row.iter(c.hugr(), *out_n).flat_map( |(input_p, v)| utils::outputs_for_variant(input_p, 1, variant_len, v) ); diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index 81e3709c4..1d77e39eb 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -2,7 +2,6 @@ use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; -use hugr_core::hugr::internal::HugrInternals; use hugr_core::{Hugr, HugrView}; use super::DFContext; @@ -25,13 +24,13 @@ impl Clone for DataflowContext { } impl Hash for DataflowContext { - fn hash(&self, state: &mut I) {} + fn hash(&self, _state: &mut I) {} } impl PartialEq for DataflowContext { fn eq(&self, other: &Self) -> bool { - // Any AscentProgram should have only one DataflowContext - assert_eq!(self as *const _, other as *const _); + // Any AscentProgram should have only one DataflowContext (maybe cloned) + assert!(Arc::ptr_eq(&self.0, &other.0)); true } } @@ -40,8 +39,8 @@ impl Eq for DataflowContext {} impl PartialOrd for DataflowContext { fn partial_cmp(&self, other: &Self) -> Option { - // Any AscentProgram should have only one DataflowContext - assert_eq!(self as *const _, other as *const _); + // Any AscentProgram should have only one DataflowContext (maybe cloned) + assert!(Arc::ptr_eq(&self.0, &other.0)); Some(std::cmp::Ordering::Equal) } } @@ -54,10 +53,8 @@ impl Deref for DataflowContext { } } -impl AsRef for DataflowContext { - fn as_ref(&self) -> &Hugr { - self.base_hugr() +impl DFContext for DataflowContext { + fn hugr(&self) -> &impl HugrView { + self.0.as_ref() } } - -impl DFContext for DataflowContext {} diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index bea8db857..783171525 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -138,7 +138,12 @@ fn test_tail_loop_iterates_twice() { // let r_w = builder // .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); let tlb = builder - .tail_loop_builder([], [(BOOL_T, false_w), (BOOL_T, true_w)], vec![].into()) + .tail_loop_builder_exts( + [], + [(BOOL_T, false_w), (BOOL_T, true_w)], + vec![].into(), + ExtensionSet::new(), + ) .unwrap(); assert_eq!( tlb.loop_signature().unwrap().dataflow_signature().unwrap(), @@ -157,7 +162,7 @@ fn test_tail_loop_iterates_twice() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); // TODO once we can do conditionals put these wires inside `just_outputs` and // we should be able to propagate their values - let [o_w1, o_w2, _] = tail_loop.outputs_arr(); + let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::new(); machine.run_hugr(&hugr); From 706c89208bea60c423cf7fa8960e5073499ccde6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 18:00:04 +0100 Subject: [PATCH 005/281] ValueKey using MaybeHash --- .../src/const_fold2/partial_value/test.rs | 3 +- .../const_fold2/partial_value/value_handle.rs | 153 +++++++++--------- 2 files changed, 83 insertions(+), 73 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value/test.rs b/hugr-passes/src/const_fold2/partial_value/test.rs index 5e3b861e3..6621f0a69 100644 --- a/hugr-passes/src/const_fold2/partial_value/test.rs +++ b/hugr-passes/src/const_fold2/partial_value/test.rs @@ -74,7 +74,8 @@ impl TestSumLeafType { (0u64..(1 << (2u64.pow(lw as u32) - 1))) .prop_map(move |x| { let ki = ConstInt::new_u(lw as u8, x).unwrap(); - ValueHandle::new(ValueKey::new(ki.clone()), Arc::new(ki.into())).into() + let k = ValueKey::try_new(ki.clone()).unwrap(); + ValueHandle::new(k, Arc::new(ki.into())).into() }) .boxed() } diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index 6a91d513a..5ffe5af21 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -1,96 +1,68 @@ -use std::any::Any; use std::hash::{DefaultHasher, Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; -use downcast_rs::Downcast; -use hugr_core::ops::constant::Sum; -use itertools::Either; +use hugr_core::ops::constant::{CustomConst, Sum}; use hugr_core::ops::Value; -use hugr_core::std_extensions::arithmetic::int_types::ConstInt; use hugr_core::Node; -pub trait ValueName: std::fmt::Debug + Downcast + Any { - fn hash(&self) -> u64; - fn eq(&self, other: &dyn ValueName) -> bool; -} - -fn hash_hash(x: &impl Hash) -> u64 { - let mut hasher = DefaultHasher::new(); - x.hash(&mut hasher); - hasher.finish() +#[derive(Clone, Debug)] +pub struct HashedConst { + hash: u64, + val: Arc, } -fn value_name_eq(x: &T, other: &dyn ValueName) -> bool { - if let Some(other) = other.as_any().downcast_ref::() { - x == other - } else { - false +impl PartialEq for HashedConst { + fn eq(&self, other: &Self) -> bool { + self.hash == other.hash && self.val.equal_consts(other.val.as_ref()) } } -impl ValueName for String { - fn hash(&self) -> u64 { - hash_hash(self) - } +impl Eq for HashedConst {} - fn eq(&self, other: &dyn ValueName) -> bool { - value_name_eq(self, other) +impl Hash for HashedConst { + fn hash(&self, state: &mut H) { + state.write_u64(self.hash); } } -impl ValueName for ConstInt { - fn hash(&self) -> u64 { - hash_hash(self) - } - - fn eq(&self, other: &dyn ValueName) -> bool { - value_name_eq(self, other) - } +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum ValueKey { + Select(usize, Box), + Const(HashedConst), + Node(Node), } -#[derive(Clone, Debug)] -pub struct ValueKey(Vec, Either>); - -impl PartialEq for ValueKey { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 - && match (&self.1, &other.1) { - (Either::Left(ref n1), Either::Left(ref n2)) => n1 == n2, - (Either::Right(ref v1), Either::Right(ref v2)) => v1.eq(v2.as_ref()), - _ => false, - } +impl From for ValueKey { + fn from(n: Node) -> Self { + Self::Node(n) } } -impl Eq for ValueKey {} - -impl Hash for ValueKey { - fn hash(&self, state: &mut H) { - self.0.hash(state); - match &self.1 { - Either::Left(n) => (0, n).hash(state), - Either::Right(v) => (1, v.hash()).hash(state), - } +impl From for ValueKey { + fn from(value: HashedConst) -> Self { + Self::Const(value) } } -impl From for ValueKey { - fn from(n: Node) -> Self { - Self(vec![], Either::Left(n)) +impl ValueKey { + pub fn new(n: Node, k: impl CustomConst) -> Self { + Self::try_new(k).unwrap_or(Self::Node(n)) } -} -impl ValueKey { - pub fn new(k: impl ValueName) -> Self { - Self(vec![], Either::Right(Arc::new(k))) + pub fn try_new(cst: impl CustomConst) -> Option { + let mut hasher = DefaultHasher::new(); + cst.maybe_hash(&mut hasher).then(|| { + Self::Const(HashedConst { + hash: hasher.finish(), + val: Arc::new(cst), + }) + }) } pub fn index(self, i: usize) -> Self { - let mut is = self.0; - is.push(i); - Self(is, self.1) + Self::Select(i, Box::new(self)) } } @@ -187,22 +159,40 @@ impl Deref for ValueHandle { #[cfg(test)] mod test { - use hugr_core::{ops::constant::CustomConst as _, types::SumType}; + use hugr_core::{ + extension::prelude::ConstString, + ops::constant::CustomConst as _, + std_extensions::{ + arithmetic::{ + float_types::{ConstF64, FLOAT64_TYPE}, + int_types::{ConstInt, INT_TYPES}, + }, + collections::ListValue, + }, + types::SumType, + }; use super::*; #[test] fn value_key_eq() { - let k1 = ValueKey::new("foo".to_string()); - let k2 = ValueKey::new("foo".to_string()); - let k3 = ValueKey::new("bar".to_string()); + let n = Node::from(portgraph::NodeIndex::new(0)); + let n2: Node = portgraph::NodeIndex::new(1).into(); + let k1 = ValueKey::new(n, ConstString::new("foo".to_string())); + let k2 = ValueKey::new(n2, ConstString::new("foo".to_string())); + let k3 = ValueKey::new(n, ConstString::new("bar".to_string())); - assert_eq!(k1, k2); + assert_eq!(k1, k2); // Node ignored assert_ne!(k1, k3); - let k4: ValueKey = Node::from(portgraph::NodeIndex::new(1)).into(); - let k5: ValueKey = Node::from(portgraph::NodeIndex::new(1)).into(); - let k6: ValueKey = Node::from(portgraph::NodeIndex::new(2)).into(); + assert_eq!(ValueKey::from(n), ValueKey::from(n)); + let f = ConstF64::new(3.141); + assert_eq!(ValueKey::new(n, f.clone()), ValueKey::from(n)); + + assert_ne!(ValueKey::new(n, f.clone()), ValueKey::new(n2, f)); // Node taken into account + let k4 = ValueKey::from(n); + let k5 = ValueKey::from(n); + let k6: ValueKey = ValueKey::from(n2); assert_eq!(&k4, &k5); assert_ne!(&k4, &k6); @@ -217,6 +207,25 @@ mod test { assert_ne!(&k5, &k7); } + #[test] + fn value_key_list() { + let v1 = ConstInt::new_u(3, 3).unwrap(); + let v2 = ConstInt::new_u(4, 3).unwrap(); + let v3 = ConstF64::new(3.141); + + let n = Node::from(portgraph::NodeIndex::new(0)); + let n2: Node = portgraph::NodeIndex::new(1).into(); + + let lst = ListValue::new(INT_TYPES[0].clone(), [v1.into(), v2.into()]); + assert_eq!(ValueKey::new(n, lst.clone()), ValueKey::new(n2, lst)); + + let lst = ListValue::new(FLOAT64_TYPE, [v3.into()]); + assert_ne!( + ValueKey::new(n, lst.clone()), + ValueKey::new(n2, lst.clone()) + ); + } + #[test] fn value_handle_eq() { let k_i = ConstInt::new_u(4, 2).unwrap(); @@ -229,7 +238,7 @@ mod test { .unwrap(), ); - let k1 = ValueKey::new("foo".to_string()); + let k1 = ValueKey::try_new(ConstString::new("foo".to_string())).unwrap(); let v1 = ValueHandle::new(k1.clone(), subject_val.clone()); let v2 = ValueHandle::new(k1.clone(), Value::extension(k_i).into()); From 63bc944c86dcf98d98414fdfbf117d27817bd6dc Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 20:43:19 +0100 Subject: [PATCH 006/281] tag() does not refer to self.is_compound --- hugr-passes/src/const_fold2/partial_value.rs | 1 - hugr-passes/src/const_fold2/partial_value/value_handle.rs | 8 ++------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index dafc48fce..b5018ce38 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -412,7 +412,6 @@ impl PartialValue { } /// TODO docs - /// just delegate to variant_field_value pub fn tuple_field_value(&self, idx: usize) -> Self { self.variant_field_value(0, idx) } diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index 5ffe5af21..ff3a1fa16 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -95,14 +95,10 @@ impl ValueHandle { } pub fn tag(&self) -> usize { - assert!( - self.is_compound(), - "ValueHandle::tag called on non-Sum, non-Tuple value: {:#?}", - self - ); match self.value() { Value::Sum(Sum { tag, .. }) => *tag, - _ => unreachable!(), + _ => panic!("ValueHandle::tag called on non-Sum, non-Tuple value: {:#?}", + self), } } From 5fa7edbcd6816df80462f478c6f4e5d1c692e2ad Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 21:13:04 +0100 Subject: [PATCH 007/281] ValueHandle::{is_compound,num_fields,index} => {variant_values, as_sum} --- hugr-passes/src/const_fold2/partial_value.rs | 30 ++++-------- .../const_fold2/partial_value/value_handle.rs | 47 +++++++------------ 2 files changed, 27 insertions(+), 50 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index b5018ce38..2e01108e5 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -1,11 +1,9 @@ #![allow(missing_docs)] +use itertools::{zip_eq, Itertools as _}; use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; -use hugr_core::ops::constant::Sum; -use itertools::{zip_eq, Itertools as _}; - use hugr_core::ops::Value; use hugr_core::types::{Type, TypeEnum, TypeRow}; @@ -22,7 +20,7 @@ impl PartialSum { Self::variant(0, []) } pub fn variant(tag: usize, values: impl IntoIterator) -> Self { - Self([(tag, values.into_iter().collect())].into_iter().collect()) + Self(HashMap::from([(tag, Vec::from_iter(values))])) } pub fn num_variants(&self) -> usize { @@ -171,16 +169,10 @@ impl TryFrom for PartialSum { type Error = ValueHandle; fn try_from(value: ValueHandle) -> Result { - match value.value() { - Value::Sum(Sum { tag, values, .. }) => { - let vec = (0..values.len()) - .map(|i| PartialValue::from(value.index(i)).into()) - .collect(); - return Ok(Self([(*tag, vec)].into_iter().collect())); - } - _ => (), - }; - Err(value) + value + .as_sum() + .map(|(tag, values)| Self::variant(tag, values.into_iter().map(PartialValue::from))) + .ok_or(value) } } @@ -421,13 +413,9 @@ impl PartialValue { match self { Self::Bottom => Self::Bottom, Self::PartialSum(ps) => ps.variant_field_value(variant, idx), - Self::Value(v) => { - if v.tag() == variant { - Self::Value(v.index(idx)) - } else { - Self::Bottom - } - } + Self::Value(v) => v + .variant_values(variant) + .map_or(Self::Bottom, |vals| Self::Value(vals[idx].clone())), Self::Top => Self::Top, } } diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index ff3a1fa16..ae5facbc5 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -78,44 +78,32 @@ impl ValueHandle { self.1.as_ref() } - pub fn is_compound(&self) -> bool { - matches!(self.value(), Value::Sum(_)) + pub fn variant_values(&self, variant: usize) -> Option> { + self.as_sum() + .and_then(|(tag, vals)| (tag == variant).then_some(vals)) } - pub fn num_fields(&self) -> usize { - assert!( - self.is_compound(), - "ValueHandle::num_fields called on non-Sum, non-Tuple value: {:#?}", - self - ); + pub fn as_sum(&self) -> Option<(usize, Vec)> { match self.value() { - Value::Sum(Sum { values, .. }) => values.len(), - _ => unreachable!(), + Value::Sum(Sum { tag, values, .. }) => { + let vals = values.iter().cloned().map(Arc::new); + let keys = (0..).map(|i| self.0.clone().index(i)); + let vec = keys.zip(vals).map(|(i, v)| Self(i, v)).collect(); + Some((*tag, vec)) + } + _ => None, } } pub fn tag(&self) -> usize { match self.value() { Value::Sum(Sum { tag, .. }) => *tag, - _ => panic!("ValueHandle::tag called on non-Sum, non-Tuple value: {:#?}", - self), + _ => panic!( + "ValueHandle::tag called on non-Sum, non-Tuple value: {:#?}", + self + ), } } - - pub fn index(self: &ValueHandle, i: usize) -> ValueHandle { - assert!( - i < self.num_fields(), - "ValueHandle::index called with out-of-bounds index {}: {:#?}", - i, - &self - ); - let vs = match self.value() { - Value::Sum(Sum { values, .. }) => values, - _ => unreachable!(), - }; - let v = vs[i].clone().into(); - Self(self.0.clone().index(i), v) - } } impl PartialEq for ValueHandle { @@ -238,8 +226,9 @@ mod test { let v1 = ValueHandle::new(k1.clone(), subject_val.clone()); let v2 = ValueHandle::new(k1.clone(), Value::extension(k_i).into()); + let (_, fields) = v1.as_sum().unwrap(); // we do not compare the value, just the key - assert_ne!(v1.index(0), v2); - assert_eq!(v1.index(0).value(), v2.value()); + assert_ne!(fields[0], v2); + assert_eq!(fields[0].value(), v2.value()); } } From 98bf94a3b35a5a9f422f03d328900b6224086ff9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 21:16:46 +0100 Subject: [PATCH 008/281] Rm ValueHandle::tag, use variant_values - inefficient, presume this is what was meant --- hugr-passes/src/const_fold2/partial_value.rs | 2 +- .../src/const_fold2/partial_value/value_handle.rs | 10 ---------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index 2e01108e5..0c7c5b4f2 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -397,7 +397,7 @@ impl PartialValue { pub fn supports_tag(&self, tag: usize) -> bool { match self { PartialValue::Bottom => false, - PartialValue::Value(v) => v.tag() == tag, // can never be a sum or tuple + PartialValue::Value(v) => v.variant_values(tag).is_some(), PartialValue::PartialSum(ps) => ps.supports_tag(tag), PartialValue::Top => true, } diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index ae5facbc5..728caeb33 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -94,16 +94,6 @@ impl ValueHandle { _ => None, } } - - pub fn tag(&self) -> usize { - match self.value() { - Value::Sum(Sum { tag, .. }) => *tag, - _ => panic!( - "ValueHandle::tag called on non-Sum, non-Tuple value: {:#?}", - self - ), - } - } } impl PartialEq for ValueHandle { From 295ec3277e180e141ae2a5fded88d894f3ce8848 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 22:09:18 +0100 Subject: [PATCH 009/281] add variant_values, rewrite one use of outputs_for_variant --- hugr-passes/src/const_fold2/datalog.rs | 9 +++------ hugr-passes/src/const_fold2/datalog/utils.rs | 16 ++++++++++++++- hugr-passes/src/const_fold2/partial_value.rs | 21 ++++++++++++++++++++ 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 96c4dd50c..1cf85f18a 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -115,12 +115,9 @@ ascent::ascent! { io_node(c,tl_n,in_n, IO::Input), io_node(c,tl_n,out_n, IO::Output), node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node - if out_in_row[0].supports_tag(0), // if it is possible for tag to be 0 if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), - let variant_len = tailloop.just_inputs.len(), - for (out_p, v) in out_in_row.iter(c.hugr(), *out_n).flat_map( - |(input_p, v)| utils::outputs_for_variant(input_p, 0, variant_len, v) - ); + if let Some(fields) = out_in_row[0].variant_values(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 + for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields.into_iter().chain(out_in_row.iter().skip(1).cloned())); // Output node of child region propagate to outputs of tail loop out_wire_value(c, tl_n, out_p, v) <-- tail_loop_node(c, tl_n), @@ -129,7 +126,7 @@ ascent::ascent! { if out_in_row[0].supports_tag(1), // if it is possible for the tag to be 1 if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), let variant_len = tailloop.just_outputs.len(), - for (out_p, v) in out_in_row.iter(c.hugr(), *out_n).flat_map( + for (out_p, v) in out_in_row.iter_with_ports(c.hugr(), *out_n).flat_map( |(input_p, v)| utils::outputs_for_variant(input_p, 1, variant_len, v) ); diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 5c2b12730..05ceccf16 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -30,6 +30,16 @@ impl PV { self.variant_field_value(0, idx) } + pub fn variant_values(&self, variant: usize, len: usize) -> Option> { + Some( + self.0 + .variant_values(variant, len)? + .into_iter() + .map(PV::from) + .collect(), + ) + } + /// TODO the arguments here are not pretty, two usizes, better not mix them /// up!!! pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { @@ -104,7 +114,11 @@ impl ValueRow { Self::new(r.len()) } - pub fn iter<'b>( + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + pub fn iter_with_ports<'b>( &'b self, h: &'b impl HugrView, n: Node, diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index 0c7c5b4f2..e7e532fdf 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -34,6 +34,12 @@ impl PartialSum { } } + pub fn variant_values(&self, variant: usize, len: usize) -> Option> { + let row = self.0.get(&variant)?; + assert!(row.len() == len); + Some(row.clone()) + } + pub fn variant_field_value(&self, variant: usize, idx: usize) -> PartialValue { if let Some(row) = self.0.get(&variant) { assert!(row.len() > idx); @@ -394,6 +400,21 @@ impl PartialValue { Self::variant(0, []) } + pub fn variant_values(&self, tag: usize, len: usize) -> Option> { + let vals = match self { + PartialValue::Bottom => return None, + PartialValue::Value(v) => v + .variant_values(tag)? + .into_iter() + .map(PartialValue::Value) + .collect(), + PartialValue::PartialSum(ps) => ps.variant_values(tag, len)?, + PartialValue::Top => vec![PartialValue::Top; len], + }; + assert_eq!(vals.len(), len); + Some(vals) + } + pub fn supports_tag(&self, tag: usize) -> bool { match self { PartialValue::Bottom => false, From 5c8289e88f94c8fac2c3a4755f768bd5d89a97ce Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 22:18:52 +0100 Subject: [PATCH 010/281] ...and the other two; remove outputs_for_variant --- hugr-passes/src/const_fold2/datalog.rs | 17 ++++---- hugr-passes/src/const_fold2/datalog/utils.rs | 42 -------------------- 2 files changed, 9 insertions(+), 50 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 1cf85f18a..71dee7483 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -117,17 +117,17 @@ ascent::ascent! { node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), if let Some(fields) = out_in_row[0].variant_values(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 - for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields.into_iter().chain(out_in_row.iter().skip(1).cloned())); + for (out_p, v) in (0..).map(OutgoingPort::from).zip( + fields.into_iter().chain(out_in_row.iter().skip(1).cloned())); // Output node of child region propagate to outputs of tail loop out_wire_value(c, tl_n, out_p, v) <-- tail_loop_node(c, tl_n), io_node(c,tl_n,out_n, IO::Output), node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node - if out_in_row[0].supports_tag(1), // if it is possible for the tag to be 1 if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), - let variant_len = tailloop.just_outputs.len(), - for (out_p, v) in out_in_row.iter_with_ports(c.hugr(), *out_n).flat_map( - |(input_p, v)| utils::outputs_for_variant(input_p, 1, variant_len, v) + if let Some(fields) = out_in_row[0].variant_values(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 + for (out_p, v) in (0..).map(OutgoingPort::from).zip( + fields.into_iter().chain(out_in_row.iter().skip(1).cloned()) ); lattice tail_loop_termination(C,Node,TailLoopTermination); @@ -152,10 +152,11 @@ ascent::ascent! { out_wire_value(c, i_node, i_p, v) <-- case_node(c, cond, case_index, case), io_node(c, case, i_node, IO::Input), - in_wire_value(c, cond, cond_in_p, cond_in_v), + node_in_value_row(c, cond, in_row), + //in_wire_value(c, cond, cond_in_p, cond_in_v), if let Some(conditional) = c.get_optype(*cond).as_conditional(), - let variant_len = conditional.sum_rows[*case_index].len(), - for (i_p, v) in utils::outputs_for_variant(*cond_in_p, *case_index, variant_len, cond_in_v); + if let Some(fields) = in_row[0].variant_values(*case_index, conditional.sum_rows[*case_index].len()), + for (i_p, v) in (0..).map(OutgoingPort::from).zip(fields.into_iter().chain(in_row.iter().skip(1).cloned())); // outputs of case nodes propagate to outputs of conditional out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 05ceccf16..d41e4bceb 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -233,48 +233,6 @@ pub(super) fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator( - output_p: IncomingPort, - variant_tag: usize, - variant_len: usize, - v: &'a PV, -) -> impl Iterator + 'a { - if output_p.index() == 0 { - Either::Left( - (0..variant_len).map(move |i| (i.into(), v.variant_field_value(variant_tag, i))), - ) - } else { - let v = if v.supports_tag(variant_tag) { - v.clone() - } else { - PV::bottom() - }; - Either::Right(std::iter::once(( - (variant_len + output_p.index() - 1).into(), - v, - ))) - } -} - #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] #[cfg_attr(test, derive(Arbitrary))] pub enum TailLoopTermination { From bf173ab4ae14279c0c32dee70f6cd80bda128ec1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 22:26:22 +0100 Subject: [PATCH 011/281] Rewrite tuple rule to avoid indexing --- hugr-passes/src/const_fold2/datalog.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 71dee7483..0fc56241c 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -85,10 +85,11 @@ ascent::ascent! { relation unpack_tuple_node(C, Node); unpack_tuple_node(c,n) <-- node(c, n), if c.get_optype(*n).is_unpack_tuple(); - out_wire_value(c, n, p, v.tuple_field_value(p.index())) <-- + out_wire_value(c, n, p, v) <-- unpack_tuple_node(c, n), - in_wire_value(c, n, IncomingPort::from(0), v), - out_wire(c, n, p); + in_wire_value(c, n, IncomingPort::from(0), tup), + if let Some(fields) = tup.variant_values(0, utils::value_outputs(c.hugr(),*n).count()), + for (p,v) in (0..).map(OutgoingPort::from).zip(fields); // DFG From 0ae4d196046d99c9ceedec904c06f54328e00eff Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 22:28:44 +0100 Subject: [PATCH 012/281] GC unused (tuple,variant)_field_value, iter_with_ports --- hugr-passes/src/const_fold2/datalog/utils.rs | 18 ------------- hugr-passes/src/const_fold2/partial_value.rs | 27 +------------------- 2 files changed, 1 insertion(+), 44 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index d41e4bceb..bebb741f8 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -26,10 +26,6 @@ impl From for PV { } impl PV { - pub fn tuple_field_value(&self, idx: usize) -> Self { - self.variant_field_value(0, idx) - } - pub fn variant_values(&self, variant: usize, len: usize) -> Option> { Some( self.0 @@ -40,12 +36,6 @@ impl PV { ) } - /// TODO the arguments here are not pretty, two usizes, better not mix them - /// up!!! - pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { - self.0.variant_field_value(variant, idx).into() - } - pub fn supports_tag(&self, tag: usize) -> bool { self.0.supports_tag(tag) } @@ -118,14 +108,6 @@ impl ValueRow { self.0.iter() } - pub fn iter_with_ports<'b>( - &'b self, - h: &'b impl HugrView, - n: Node, - ) -> impl Iterator + 'b { - zip_eq(value_inputs(h, n), self.0.iter()) - } - // fn initialised(&self) -> bool { // self.0.iter().all(|x| x != &PV::top()) // } diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index e7e532fdf..4337e5bb6 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -40,15 +40,6 @@ impl PartialSum { Some(row.clone()) } - pub fn variant_field_value(&self, variant: usize, idx: usize) -> PartialValue { - if let Some(row) = self.0.get(&variant) { - assert!(row.len() > idx); - row[idx].clone() - } else { - PartialValue::bottom() - } - } - pub fn try_into_value(self, typ: &Type) -> Result { let Ok((k, v)) = self.0.iter().exactly_one() else { Err(self)? @@ -418,28 +409,12 @@ impl PartialValue { pub fn supports_tag(&self, tag: usize) -> bool { match self { PartialValue::Bottom => false, + // TODO this is wildly expensive - only used for case reachability but still... PartialValue::Value(v) => v.variant_values(tag).is_some(), PartialValue::PartialSum(ps) => ps.supports_tag(tag), PartialValue::Top => true, } } - - /// TODO docs - pub fn tuple_field_value(&self, idx: usize) -> Self { - self.variant_field_value(0, idx) - } - - /// TODO docs - pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { - match self { - Self::Bottom => Self::Bottom, - Self::PartialSum(ps) => ps.variant_field_value(variant, idx), - Self::Value(v) => v - .variant_values(variant) - .map_or(Self::Bottom, |vals| Self::Value(vals[idx].clone())), - Self::Top => Self::Top, - } - } } impl PartialOrd for PartialValue { From 8608ba9a9a7e2400f964e878207956f5b23d029e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 22:33:53 +0100 Subject: [PATCH 013/281] Common up via ValueRow.unpack_first --- hugr-passes/src/const_fold2/datalog.rs | 15 ++++++--------- hugr-passes/src/const_fold2/datalog/utils.rs | 10 ++++++++++ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 0fc56241c..2944503c7 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -117,19 +117,16 @@ ascent::ascent! { io_node(c,tl_n,out_n, IO::Output), node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), - if let Some(fields) = out_in_row[0].variant_values(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 - for (out_p, v) in (0..).map(OutgoingPort::from).zip( - fields.into_iter().chain(out_in_row.iter().skip(1).cloned())); + if let Some(fields) = out_in_row.unpack_first(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 + for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); // Output node of child region propagate to outputs of tail loop out_wire_value(c, tl_n, out_p, v) <-- tail_loop_node(c, tl_n), io_node(c,tl_n,out_n, IO::Output), node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), - if let Some(fields) = out_in_row[0].variant_values(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 - for (out_p, v) in (0..).map(OutgoingPort::from).zip( - fields.into_iter().chain(out_in_row.iter().skip(1).cloned()) - ); + if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 + for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); lattice tail_loop_termination(C,Node,TailLoopTermination); tail_loop_termination(c,tl_n,TailLoopTermination::bottom()) <-- @@ -156,8 +153,8 @@ ascent::ascent! { node_in_value_row(c, cond, in_row), //in_wire_value(c, cond, cond_in_p, cond_in_v), if let Some(conditional) = c.get_optype(*cond).as_conditional(), - if let Some(fields) = in_row[0].variant_values(*case_index, conditional.sum_rows[*case_index].len()), - for (i_p, v) in (0..).map(OutgoingPort::from).zip(fields.into_iter().chain(in_row.iter().skip(1).cloned())); + if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), + for (i_p, v) in (0..).map(OutgoingPort::from).zip(fields); // outputs of case nodes propagate to outputs of conditional out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index bebb741f8..5a9ac8495 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -108,6 +108,16 @@ impl ValueRow { self.0.iter() } + pub fn unpack_first( + &self, + variant: usize, + len: usize, + ) -> Option + '_> { + self[0] + .variant_values(variant, len) + .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) + } + // fn initialised(&self) -> bool { // self.0.iter().all(|x| x != &PV::top()) // } From 2dca3e9fcb06d73cf7c6c8728d11a5a80412433a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 19:25:23 +0100 Subject: [PATCH 014/281] No DeRef for ValueHandle, just add get_type() --- .../const_fold2/partial_value/value_handle.rs | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index 728caeb33..6a5d9dd81 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -1,10 +1,10 @@ use std::hash::{DefaultHasher, Hash, Hasher}; -use std::ops::Deref; use std::sync::Arc; use hugr_core::ops::constant::{CustomConst, Sum}; use hugr_core::ops::Value; +use hugr_core::types::Type; use hugr_core::Node; #[derive(Clone, Debug)] @@ -94,6 +94,10 @@ impl ValueHandle { _ => None, } } + + pub fn get_type(&self) -> Type { + self.1.get_type() + } } impl PartialEq for ValueHandle { @@ -119,18 +123,6 @@ impl Hash for ValueHandle { } } -/// TODO this is perhaps dodgy -/// we do not hash or compare the value, just the key -/// this means two handles with different keys, but with the same value, will -/// not compare equal. -impl Deref for ValueHandle { - type Target = Value; - - fn deref(&self) -> &Self::Target { - self.value() - } -} - #[cfg(test)] mod test { use hugr_core::{ From 51e68ea99a44addad85d4aaee18d425b1d52829d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 19:27:22 +0100 Subject: [PATCH 015/281] ValueKey::{Select->Field,index->field} --- .../src/const_fold2/partial_value/value_handle.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index 6a5d9dd81..3b450c178 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -29,7 +29,7 @@ impl Hash for HashedConst { #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum ValueKey { - Select(usize, Box), + Field(usize, Box), Const(HashedConst), Node(Node), } @@ -61,8 +61,8 @@ impl ValueKey { }) } - pub fn index(self, i: usize) -> Self { - Self::Select(i, Box::new(self)) + pub fn field(self, i: usize) -> Self { + Self::Field(i, Box::new(self)) } } @@ -87,7 +87,7 @@ impl ValueHandle { match self.value() { Value::Sum(Sum { tag, values, .. }) => { let vals = values.iter().cloned().map(Arc::new); - let keys = (0..).map(|i| self.0.clone().index(i)); + let keys = (0..).map(|i| self.0.clone().field(i)); let vec = keys.zip(vals).map(|(i, v)| Self(i, v)).collect(); Some((*tag, vec)) } @@ -163,12 +163,12 @@ mod test { assert_eq!(&k4, &k5); assert_ne!(&k4, &k6); - let k7 = k5.clone().index(3); - let k4 = k4.index(3); + let k7 = k5.clone().field(3); + let k4 = k4.field(3); assert_eq!(&k4, &k7); - let k5 = k5.index(2); + let k5 = k5.field(2); assert_ne!(&k5, &k7); } From 863547413877b7ca3e305453da1726ef30a5b792 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 7 Aug 2024 11:24:02 +0100 Subject: [PATCH 016/281] (join/meet)_mut_unsafe => try_(join/meet)_mut with Err for conflicting len --- hugr-passes/src/const_fold2/partial_value.rs | 43 +++++++++++++++----- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index 4337e5bb6..dc0ea005b 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -66,12 +66,16 @@ impl PartialSum { } } - // unsafe because we panic if any common rows have different lengths - fn join_mut_unsafe(&mut self, other: Self) -> bool { + // Err with key if any common rows have different lengths (self may have been mutated) + fn try_join_mut(&mut self, other: Self) -> Result { let mut changed = false; for (k, v) in other.0 { if let Some(row) = self.0.get_mut(&k) { + if v.len() != row.len() { + // Better to check first and avoid mutation, but fine here + return Err(k); + } for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { changed |= lhs.join_mut(rhs); } @@ -80,16 +84,21 @@ impl PartialSum { changed = true; } } - changed + Ok(changed) } - // unsafe because we panic if any common rows have different lengths - fn meet_mut_unsafe(&mut self, other: Self) -> bool { + // Error with key if any common rows have different lengths ( => Bottom) + fn try_meet_mut(&mut self, other: Self) -> Result { let mut changed = false; let mut keys_to_remove = vec![]; - for k in self.0.keys() { - if !other.0.contains_key(k) { - keys_to_remove.push(*k); + for (k, v) in self.0.iter() { + match other.0.get(k) { + None => keys_to_remove.push(*k), + Some(o_v) => { + if v.len() != o_v.len() { + return Err(*k); + } + } } } for (k, v) in other.0 { @@ -105,7 +114,7 @@ impl PartialSum { self.0.remove(&k); changed = true; } - changed + Ok(changed) } pub fn supports_tag(&self, tag: usize) -> bool { @@ -304,7 +313,13 @@ impl PartialValue { let Self::PartialSum(ps1) = self else { unreachable!() }; - ps1.join_mut_unsafe(ps2) + match ps1.try_join_mut(ps2) { + Ok(ch) => ch, + Err(_) => { + *self = Self::Top; + true + } + } } (Self::Value(_), mut other) => { std::mem::swap(self, &mut other); @@ -354,7 +369,13 @@ impl PartialValue { let Self::PartialSum(ps1) = self else { unreachable!() }; - ps1.meet_mut_unsafe(ps2) + match ps1.try_meet_mut(ps2) { + Ok(ch) => ch, + Err(_) => { + *self = Self::Bottom; + true + } + } } (Self::Value(_), mut other @ Self::PartialSum(_)) => { std::mem::swap(self, &mut other); From 80d5b866903e06e1aa3cc341cf4309e409b1bbc5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 8 Aug 2024 15:02:03 +0100 Subject: [PATCH 017/281] Remove ValueHandle::variant_values - just have as_sum --- hugr-passes/src/const_fold2/partial_value.rs | 6 ++++-- hugr-passes/src/const_fold2/partial_value/value_handle.rs | 5 ----- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index dc0ea005b..4bb56222d 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -416,7 +416,9 @@ impl PartialValue { let vals = match self { PartialValue::Bottom => return None, PartialValue::Value(v) => v - .variant_values(tag)? + .as_sum() + .filter(|(variant, _)| tag == *variant)? + .1 .into_iter() .map(PartialValue::Value) .collect(), @@ -431,7 +433,7 @@ impl PartialValue { match self { PartialValue::Bottom => false, // TODO this is wildly expensive - only used for case reachability but still... - PartialValue::Value(v) => v.variant_values(tag).is_some(), + PartialValue::Value(v) => v.as_sum().is_some_and(|(v, _)| v == tag), PartialValue::PartialSum(ps) => ps.supports_tag(tag), PartialValue::Top => true, } diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index 3b450c178..6a4d70a60 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -78,11 +78,6 @@ impl ValueHandle { self.1.as_ref() } - pub fn variant_values(&self, variant: usize) -> Option> { - self.as_sum() - .and_then(|(tag, vals)| (tag == variant).then_some(vals)) - } - pub fn as_sum(&self) -> Option<(usize, Vec)> { match self.value() { Value::Sum(Sum { tag, values, .. }) => { From 1c8be9989380588ed4c48c6220bc2c05d5fa4101 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 8 Aug 2024 15:06:25 +0100 Subject: [PATCH 018/281] Optimize as_sum() by returning impl Iterator not Vec --- hugr-passes/src/const_fold2/partial_value.rs | 4 +--- .../const_fold2/partial_value/value_handle.rs | 17 +++++++++-------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index 4bb56222d..3ba8e7c57 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -177,7 +177,7 @@ impl TryFrom for PartialSum { fn try_from(value: ValueHandle) -> Result { value .as_sum() - .map(|(tag, values)| Self::variant(tag, values.into_iter().map(PartialValue::from))) + .map(|(tag, values)| Self::variant(tag, values.map(PartialValue::from))) .ok_or(value) } } @@ -419,7 +419,6 @@ impl PartialValue { .as_sum() .filter(|(variant, _)| tag == *variant)? .1 - .into_iter() .map(PartialValue::Value) .collect(), PartialValue::PartialSum(ps) => ps.variant_values(tag, len)?, @@ -432,7 +431,6 @@ impl PartialValue { pub fn supports_tag(&self, tag: usize) -> bool { match self { PartialValue::Bottom => false, - // TODO this is wildly expensive - only used for case reachability but still... PartialValue::Value(v) => v.as_sum().is_some_and(|(v, _)| v == tag), PartialValue::PartialSum(ps) => ps.supports_tag(tag), PartialValue::Top => true, diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index 6a4d70a60..147048ae7 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -78,14 +78,15 @@ impl ValueHandle { self.1.as_ref() } - pub fn as_sum(&self) -> Option<(usize, Vec)> { + pub fn as_sum(&self) -> Option<(usize, impl Iterator + '_)> { match self.value() { - Value::Sum(Sum { tag, values, .. }) => { - let vals = values.iter().cloned().map(Arc::new); - let keys = (0..).map(|i| self.0.clone().field(i)); - let vec = keys.zip(vals).map(|(i, v)| Self(i, v)).collect(); - Some((*tag, vec)) - } + Value::Sum(Sum { tag, values, .. }) => Some(( + *tag, + values + .iter() + .enumerate() + .map(|(i, v)| Self(self.0.clone().field(i), Arc::new(v.clone()))), + )), _ => None, } } @@ -203,7 +204,7 @@ mod test { let v1 = ValueHandle::new(k1.clone(), subject_val.clone()); let v2 = ValueHandle::new(k1.clone(), Value::extension(k_i).into()); - let (_, fields) = v1.as_sum().unwrap(); + let fields = v1.as_sum().unwrap().1.collect::>(); // we do not compare the value, just the key assert_ne!(fields[0], v2); assert_eq!(fields[0].value(), v2.value()); From b0afa54aab5c94cca1e84b44a4734c6aaee5db94 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 27 Aug 2024 17:27:09 +0100 Subject: [PATCH 019/281] Machine uses PV not PartialValue --- hugr-passes/src/const_fold2/datalog.rs | 10 +++------- hugr-passes/src/const_fold2/datalog/test.rs | 6 +++--- hugr-passes/src/const_fold2/datalog/utils.rs | 8 +++++++- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 2944503c7..04a63a608 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -2,7 +2,6 @@ use ascent::lattice::BoundedLattice; use std::collections::HashMap; use std::hash::Hash; -use super::partial_value::PartialValue; use hugr_core::ops::Value; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; @@ -170,10 +169,7 @@ ascent::ascent! { } // TODO This should probably be called 'Analyser' or something -struct Machine( - AscentProgram>, - Option>, -); +struct Machine(AscentProgram>, Option>); /// Usage: /// 1. [Self::new()] @@ -185,7 +181,7 @@ impl Machine { Self(Default::default(), None) } - pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator) { + pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator) { assert!(self.1.is_none()); self.0.out_wire_value_proto.extend( wires @@ -207,7 +203,7 @@ impl Machine { ) } - pub fn read_out_wire_partial_value(&self, w: Wire) -> Option { + pub fn read_out_wire_partial_value(&self, w: Wire) -> Option { self.1.as_ref().unwrap().get(&w).cloned() } diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index 783171525..2f3ad5d5d 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -118,9 +118,9 @@ fn test_tail_loop_always_iterates() { machine.run_hugr(&hugr); let o_r1 = machine.read_out_wire_partial_value(tl_o1).unwrap(); - assert_eq!(o_r1, PartialValue::bottom()); + assert_eq!(o_r1, PartialValue::bottom().into()); let o_r2 = machine.read_out_wire_partial_value(tl_o2).unwrap(); - assert_eq!(o_r2, PartialValue::bottom()); + assert_eq!(o_r2, PartialValue::bottom().into()); assert_eq!( TailLoopTermination::bottom(), machine.tail_loop_terminates(&hugr, tail_loop.node()) @@ -220,7 +220,7 @@ fn conditional() { let mut machine = Machine::new(); let arg_pv = PartialValue::variant(1, []).join(PartialValue::variant(2, [PartialValue::variant(0, [])])); - machine.propolutate_out_wires([(arg_w, arg_pv)]); + machine.propolutate_out_wires([(arg_w, arg_pv.into())]); machine.run_hugr(&hugr); let cond_r1 = machine.read_out_wire_value(&hugr, cond_o1).unwrap(); diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 5a9ac8495..e10e96ed3 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -10,7 +10,9 @@ use itertools::{zip_eq, Either}; use crate::const_fold2::partial_value::{PartialValue, ValueHandle}; use hugr_core::{ - ops::OpTrait as _, types::TypeRow, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, + ops::{OpTrait as _, Value}, + types::{Type, TypeRow}, + HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, }; #[cfg(test)] @@ -39,6 +41,10 @@ impl PV { pub fn supports_tag(&self, tag: usize) -> bool { self.0.supports_tag(tag) } + + pub fn try_into_value(self, ty: &Type) -> Result { + self.0.try_into_value(ty).map_err(Self) + } } impl From for PartialValue { From d09a1fe769e477d82533e1b42ba9d6a00f0d90bb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 12:15:40 +0100 Subject: [PATCH 020/281] Parametrize PartialValue+PV+Machine by AbstractValue/Into, Context interprets load_constant --- hugr-passes/src/const_fold2.rs | 5 +- hugr-passes/src/const_fold2/datalog.rs | 61 ++++--- .../src/const_fold2/datalog/context.rs | 15 +- hugr-passes/src/const_fold2/datalog/test.rs | 18 +- hugr-passes/src/const_fold2/datalog/utils.rs | 110 ++++++------ hugr-passes/src/const_fold2/partial_value.rs | 161 +++++++++--------- .../src/const_fold2/partial_value/test.rs | 20 ++- .../{partial_value => }/value_handle.rs | 21 ++- 8 files changed, 227 insertions(+), 184 deletions(-) rename hugr-passes/src/const_fold2/{partial_value => }/value_handle.rs (95%) diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index 96af004e1..13af5c709 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -1,2 +1,3 @@ -mod datalog; -pub mod partial_value; +pub mod datalog; +mod partial_value; +pub mod value_handle; diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 04a63a608..c06b2a285 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -8,11 +8,13 @@ use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _ mod context; mod utils; -use context::DataflowContext; pub use utils::{TailLoopTermination, ValueRow, IO, PV}; -pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { +use super::partial_value::AbstractValue; + +pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { fn hugr(&self) -> &impl HugrView; + fn value_from_load_constant(&self, node: Node) -> V; } ascent::ascent! { @@ -20,18 +22,18 @@ ascent::ascent! { // DataflowContext (for H: HugrView) would be sufficient, there's really no // point in using anything else yet. However DFContext will be useful when we // move interpretation of nodes out into a trait method. - struct AscentProgram; + struct AscentProgram>; relation context(C); - relation out_wire_value_proto(Node, OutgoingPort, PV); + relation out_wire_value_proto(Node, OutgoingPort, PV); relation node(C, Node); relation in_wire(C, Node, IncomingPort); relation out_wire(C, Node, OutgoingPort); relation parent_of_node(C, Node, Node); relation io_node(C, Node, Node, IO); - lattice out_wire_value(C, Node, OutgoingPort, PV); - lattice node_in_value_row(C, Node, ValueRow); - lattice in_wire_value(C, Node, IncomingPort, PV); + lattice out_wire_value(C, Node, OutgoingPort, PV); + lattice node_in_value_row(C, Node, ValueRow); + lattice in_wire_value(C, Node, IncomingPort, PV); node(c, n) <-- context(c), for n in c.nodes(); @@ -68,7 +70,7 @@ ascent::ascent! { relation load_constant_node(C, Node); load_constant_node(c, n) <-- node(c, n), if c.get_optype(*n).is_load_constant(); - out_wire_value(c, n, 0.into(), utils::partial_value_from_load_constant(c.hugr(), *n)) <-- + out_wire_value(c, n, 0.into(), PV::from(c.value_from_load_constant(*n))) <-- load_constant_node(c, n); @@ -169,19 +171,22 @@ ascent::ascent! { } // TODO This should probably be called 'Analyser' or something -struct Machine(AscentProgram>, Option>); +pub struct Machine>( + AscentProgram, + Option>>, +); /// Usage: /// 1. [Self::new()] /// 2. Zero or more [Self::propolutate_out_wires] with initial values /// 3. Exactly one [Self::run_hugr] to do the analysis /// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] -impl Machine { +impl> Machine { pub fn new() -> Self { Self(Default::default(), None) } - pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator) { + pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator)>) { assert!(self.1.is_none()); self.0.out_wire_value_proto.extend( wires @@ -190,9 +195,9 @@ impl Machine { ); } - pub fn run_hugr(&mut self, hugr: H) { + pub fn run(&mut self, context: C) { assert!(self.1.is_none()); - self.0.context.push((DataflowContext::new(hugr),)); + self.0.context.push((context,)); self.0.run(); self.1 = Some( self.0 @@ -203,22 +208,11 @@ impl Machine { ) } - pub fn read_out_wire_partial_value(&self, w: Wire) -> Option { + pub fn read_out_wire_partial_value(&self, w: Wire) -> Option> { self.1.as_ref().unwrap().get(&w).cloned() } - pub fn read_out_wire_value(&self, hugr: H, w: Wire) -> Option { - // dbg!(&w); - let pv = self.read_out_wire_partial_value(w)?; - // dbg!(&pv); - let (_, typ) = hugr - .out_value_types(w.node()) - .find(|(p, _)| *p == w.source()) - .unwrap(); - pv.try_into_value(&typ).ok() - } - - pub fn tail_loop_terminates(&self, hugr: H, node: Node) -> TailLoopTermination { + pub fn tail_loop_terminates(&self, hugr: impl HugrView, node: Node) -> TailLoopTermination { assert!(hugr.get_optype(node).is_tail_loop()); self.0 .tail_loop_termination @@ -227,7 +221,7 @@ impl Machine { .unwrap() } - pub fn case_reachable(&self, hugr: H, case: Node) -> bool { + pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> bool { assert!(hugr.get_optype(case).is_case()); let cond = hugr.get_parent(case).unwrap(); assert!(hugr.get_optype(cond).is_conditional()); @@ -239,5 +233,18 @@ impl Machine { } } +impl, C: DFContext> Machine { + pub fn read_out_wire_value(&self, hugr: impl HugrView, w: Wire) -> Option { + // dbg!(&w); + let pv = self.read_out_wire_partial_value(w)?; + // dbg!(&pv); + let (_, typ) = hugr + .out_value_types(w.node()) + .find(|(p, _)| *p == w.source()) + .unwrap(); + pv.try_into_value(&typ).ok() + } +} + #[cfg(test)] mod test; diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index 1d77e39eb..31a3233fd 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -4,6 +4,8 @@ use std::sync::Arc; use hugr_core::{Hugr, HugrView}; +use crate::const_fold2::value_handle::ValueHandle; + use super::DFContext; #[derive(Debug)] @@ -53,8 +55,19 @@ impl Deref for DataflowContext { } } -impl DFContext for DataflowContext { +impl DFContext for DataflowContext { fn hugr(&self) -> &impl HugrView { self.0.as_ref() } + + fn value_from_load_constant(&self, node: hugr_core::Node) -> ValueHandle { + let load_op = self.0.get_optype(node).as_load_constant().unwrap(); + let const_node = self + .0 + .single_linked_output(node, load_op.constant_port()) + .unwrap() + .0; + let const_op = self.0.get_optype(const_node).as_const().unwrap(); + ValueHandle::new(const_node.into(), Arc::new(const_op.value().clone())) + } } diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index 2f3ad5d5d..f80cc903e 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -1,3 +1,4 @@ +use context::DataflowContext; use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{prelude::BOOL_T, ExtensionSet, EMPTY_REG}, @@ -6,8 +7,7 @@ use hugr_core::{ types::{Signature, SumType, Type, TypeRow}, }; -use crate::const_fold2::partial_value::PartialValue; - +use super::super::partial_value::PartialValue; use super::*; #[test] @@ -19,7 +19,7 @@ fn test_make_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run_hugr(&hugr); + machine.run(DataflowContext::new(&hugr)); let x = machine.read_out_wire_value(&hugr, v3).unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); @@ -38,7 +38,7 @@ fn test_unpack_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run_hugr(&hugr); + machine.run(DataflowContext::new(&hugr)); let o1_r = machine.read_out_wire_value(&hugr, o1).unwrap(); assert_eq!(o1_r, Value::false_val()); @@ -57,7 +57,7 @@ fn test_unpack_const() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run_hugr(&hugr); + machine.run(DataflowContext::new(&hugr)); let o_r = machine.read_out_wire_value(&hugr, o).unwrap(); assert_eq!(o_r, Value::true_val()); @@ -83,7 +83,7 @@ fn test_tail_loop_never_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run_hugr(&hugr); + machine.run(DataflowContext::new(&hugr)); // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); @@ -115,7 +115,7 @@ fn test_tail_loop_always_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run_hugr(&hugr); + machine.run(DataflowContext::new(&hugr)); let o_r1 = machine.read_out_wire_partial_value(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom().into()); @@ -165,7 +165,7 @@ fn test_tail_loop_iterates_twice() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::new(); - machine.run_hugr(&hugr); + machine.run(DataflowContext::new(&hugr)); // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); @@ -221,7 +221,7 @@ fn conditional() { let arg_pv = PartialValue::variant(1, []).join(PartialValue::variant(2, [PartialValue::variant(0, [])])); machine.propolutate_out_wires([(arg_w, arg_pv.into())]); - machine.run_hugr(&hugr); + machine.run(DataflowContext::new(&hugr)); let cond_r1 = machine.read_out_wire_value(&hugr, cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index e10e96ed3..42138396e 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -3,32 +3,40 @@ // https://github.com/proptest-rs/proptest/issues/447 #![cfg_attr(test, allow(non_local_definitions))] -use std::{cmp::Ordering, ops::Index, sync::Arc}; +use std::{cmp::Ordering, ops::Index}; use ascent::lattice::{BoundedLattice, Lattice}; -use itertools::{zip_eq, Either}; +use itertools::zip_eq; -use crate::const_fold2::partial_value::{PartialValue, ValueHandle}; +use crate::const_fold2::partial_value::{AbstractValue, PartialValue}; use hugr_core::{ ops::{OpTrait as _, Value}, - types::{Type, TypeRow}, + types::{Signature, Type, TypeRow}, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, }; #[cfg(test)] use proptest_derive::Arbitrary; -#[derive(PartialEq, Eq, Hash, PartialOrd, Clone, Debug)] -pub struct PV(PartialValue); +#[derive(PartialEq, Eq, Hash, Clone, Debug)] +pub struct PV(PartialValue); -impl From for PV { - fn from(inner: PartialValue) -> Self { +// Implement manually as PartialValue is PartialOrd even when V isn't +// (deriving PartialOrd conditions on V: PartialOrd, which is not necessary) +impl PartialOrd for PV { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} + +impl From> for PV { + fn from(inner: PartialValue) -> Self { Self(inner) } } -impl PV { - pub fn variant_values(&self, variant: usize, len: usize) -> Option> { +impl PV { + pub fn variant_values(&self, variant: usize, len: usize) -> Option>> { Some( self.0 .variant_values(variant, len)? @@ -41,25 +49,27 @@ impl PV { pub fn supports_tag(&self, tag: usize) -> bool { self.0.supports_tag(tag) } +} +impl> PV { pub fn try_into_value(self, ty: &Type) -> Result { self.0.try_into_value(ty).map_err(Self) } } -impl From for PartialValue { - fn from(value: PV) -> Self { +impl From> for PartialValue { + fn from(value: PV) -> Self { value.0 } } -impl From for PV { - fn from(inner: ValueHandle) -> Self { +impl From for PV { + fn from(inner: V) -> Self { Self(inner.into()) } } -impl Lattice for PV { +impl Lattice for PV { fn meet(self, other: Self) -> Self { self.0.meet(other.0).into() } @@ -77,7 +87,7 @@ impl Lattice for PV { } } -impl BoundedLattice for PV { +impl BoundedLattice for PV { fn bottom() -> Self { PartialValue::bottom().into() } @@ -87,22 +97,22 @@ impl BoundedLattice for PV { } } -#[derive(PartialEq, Clone, Eq, Hash, PartialOrd)] -pub struct ValueRow(Vec); +#[derive(PartialEq, Clone, Eq, Hash)] +pub struct ValueRow(Vec>); -impl ValueRow { +impl ValueRow { fn new(len: usize) -> Self { Self(vec![PV::bottom(); len]) } - fn singleton(len: usize, idx: usize, v: PV) -> Self { + fn singleton(len: usize, idx: usize, v: PV) -> Self { assert!(idx < len); let mut r = Self::new(len); r.0[idx] = v; r } - fn singleton_from_row(r: &TypeRow, idx: usize, v: PV) -> Self { + fn singleton_from_row(r: &TypeRow, idx: usize, v: PV) -> Self { Self::singleton(r.len(), idx, v) } @@ -110,7 +120,7 @@ impl ValueRow { Self::new(r.len()) } - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator> { self.0.iter() } @@ -118,7 +128,7 @@ impl ValueRow { &self, variant: usize, len: usize, - ) -> Option + '_> { + ) -> Option> + '_> { self[0] .variant_values(variant, len) .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) @@ -129,7 +139,13 @@ impl ValueRow { // } } -impl Lattice for ValueRow { +impl PartialOrd for ValueRow { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} + +impl Lattice for ValueRow { fn meet(mut self, other: Self) -> Self { self.meet_mut(other); self @@ -159,36 +175,42 @@ impl Lattice for ValueRow { } } -impl IntoIterator for ValueRow { - type Item = PV; +impl IntoIterator for ValueRow { + type Item = PV; - type IntoIter = as IntoIterator>::IntoIter; + type IntoIter = > as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() } } -impl Index for ValueRow +impl Index for ValueRow where - Vec: Index, + Vec>: Index, { - type Output = as Index>::Output; + type Output = > as Index>::Output; fn index(&self, index: Idx) -> &Self::Output { self.0.index(index) } } -pub(super) fn bottom_row(h: &impl HugrView, n: Node) -> ValueRow { - if let Some(sig) = h.signature(n) { - ValueRow::new(sig.input_count()) - } else { - ValueRow::new(0) - } +pub(super) fn bottom_row(h: &impl HugrView, n: Node) -> ValueRow { + ValueRow::new( + h.signature(n) + .as_ref() + .map(Signature::input_count) + .unwrap_or(0), + ) } -pub(super) fn singleton_in_row(h: &impl HugrView, n: &Node, ip: &IncomingPort, v: PV) -> ValueRow { +pub(super) fn singleton_in_row( + h: &impl HugrView, + n: &Node, + ip: &IncomingPort, + v: PV, +) -> ValueRow { let Some(sig) = h.signature(*n) else { panic!("dougrulz"); }; @@ -203,17 +225,7 @@ pub(super) fn singleton_in_row(h: &impl HugrView, n: &Node, ip: &IncomingPort, v ValueRow::singleton_from_row(&h.signature(*n).unwrap().input, ip.index(), v) } -pub(super) fn partial_value_from_load_constant(h: &impl HugrView, node: Node) -> PV { - let load_op = h.get_optype(node).as_load_constant().unwrap(); - let const_node = h - .single_linked_output(node, load_op.constant_port()) - .unwrap() - .0; - let const_op = h.get_optype(const_node).as_const().unwrap(); - ValueHandle::new(const_node.into(), Arc::new(const_op.value().clone())).into() -} - -pub(super) fn partial_value_tuple_from_value_row(r: ValueRow) -> PV { +pub(super) fn partial_value_tuple_from_value_row(r: ValueRow) -> PV { PartialValue::variant(0, r.into_iter().map(|x| x.0)).into() } @@ -240,7 +252,7 @@ pub enum TailLoopTermination { } impl TailLoopTermination { - pub fn from_control_value(v: &PV) -> Self { + pub fn from_control_value(v: &PV) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); if may_break && !may_continue { Self::ExactlyZeroContinues diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index 3ba8e7c57..6a2a614a8 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -1,32 +1,34 @@ #![allow(missing_docs)] -use itertools::{zip_eq, Itertools as _}; -use std::cmp::Ordering; -use std::collections::HashMap; -use std::hash::{Hash, Hasher}; use hugr_core::ops::Value; use hugr_core::types::{Type, TypeEnum, TypeRow}; +use itertools::{zip_eq, Itertools}; +use std::cmp::Ordering; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; -mod value_handle; - -pub use value_handle::{ValueHandle, ValueKey}; +pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { + fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; +} -// TODO ALAN inline into PartialValue +// TODO ALAN inline into PartialValue? Has to be public as it's in a pub enum #[derive(PartialEq, Clone, Eq)] -struct PartialSum(HashMap>); +pub struct PartialSum(pub HashMap>>); -impl PartialSum { +impl PartialSum { pub fn unit() -> Self { Self::variant(0, []) } - pub fn variant(tag: usize, values: impl IntoIterator) -> Self { + pub fn variant(tag: usize, values: impl IntoIterator>) -> Self { Self(HashMap::from([(tag, Vec::from_iter(values))])) } pub fn num_variants(&self) -> usize { self.0.len() } +} +impl PartialSum { fn assert_variants(&self) { assert_ne!(self.num_variants(), 0); for pv in self.0.values().flat_map(|x| x.iter()) { @@ -34,38 +36,6 @@ impl PartialSum { } } - pub fn variant_values(&self, variant: usize, len: usize) -> Option> { - let row = self.0.get(&variant)?; - assert!(row.len() == len); - Some(row.clone()) - } - - pub fn try_into_value(self, typ: &Type) -> Result { - let Ok((k, v)) = self.0.iter().exactly_one() else { - Err(self)? - }; - - let TypeEnum::Sum(st) = typ.as_type_enum() else { - Err(self)? - }; - let Some(r) = st.get_variant(*k) else { - Err(self)? - }; - let Ok(r): Result = r.clone().try_into() else { - Err(self)? - }; - if v.len() != r.len() { - return Err(self); - } - match zip_eq(v.into_iter(), r.into_iter()) - .map(|(v, t)| v.clone().try_into_value(t)) - .collect::, _>>() - { - Ok(vs) => Value::sum(*k, vs, st.clone()).map_err(|_| self), - Err(_) => Err(self), - } - } - // Err with key if any common rows have different lengths (self may have been mutated) fn try_join_mut(&mut self, other: Self) -> Result { let mut changed = false; @@ -122,7 +92,43 @@ impl PartialSum { } } -impl PartialOrd for PartialSum { +impl> PartialSum { + pub fn try_into_value(self, typ: &Type) -> Result { + let Ok((k, v)) = self.0.iter().exactly_one() else { + Err(self)? + }; + + let TypeEnum::Sum(st) = typ.as_type_enum() else { + Err(self)? + }; + let Some(r) = st.get_variant(*k) else { + Err(self)? + }; + let Ok(r): Result = r.clone().try_into() else { + Err(self)? + }; + if v.len() != r.len() { + return Err(self); + } + match zip_eq(v.into_iter(), r.into_iter()) + .map(|(v, t)| v.clone().try_into_value(t)) + .collect::, _>>() + { + Ok(vs) => Value::sum(*k, vs, st.clone()).map_err(|_| self), + Err(_) => Err(self), + } + } +} + +impl PartialSum { + pub fn variant_values(&self, variant: usize, len: usize) -> Option>> { + let row = self.0.get(&variant)?; + assert!(row.len() == len); + Some(row.clone()) + } +} + +impl PartialOrd for PartialSum { fn partial_cmp(&self, other: &Self) -> Option { let max_key = self.0.keys().chain(other.0.keys()).copied().max().unwrap(); let (mut keys1, mut keys2) = (vec![0; max_key + 1], vec![0; max_key + 1]); @@ -156,13 +162,13 @@ impl PartialOrd for PartialSum { } } -impl std::fmt::Debug for PartialSum { +impl std::fmt::Debug for PartialSum { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } } -impl Hash for PartialSum { +impl Hash for PartialSum { fn hash(&self, state: &mut H) { for (k, v) in &self.0 { k.hash(state); @@ -171,38 +177,29 @@ impl Hash for PartialSum { } } -impl TryFrom for PartialSum { - type Error = ValueHandle; - - fn try_from(value: ValueHandle) -> Result { - value - .as_sum() - .map(|(tag, values)| Self::variant(tag, values.map(PartialValue::from))) - .ok_or(value) - } -} - #[derive(PartialEq, Clone, Eq, Hash, Debug)] -pub enum PartialValue { +pub enum PartialValue { Bottom, - Value(ValueHandle), - PartialSum(PartialSum), + Value(V), + PartialSum(PartialSum), Top, } -impl From for PartialValue { - fn from(v: ValueHandle) -> Self { - TryInto::::try_into(v).map_or_else(Self::Value, Self::PartialSum) +impl From for PartialValue { + fn from(v: V) -> Self { + v.as_sum() + .map(|(tag, values)| Self::variant(tag, values.map(Self::Value))) + .unwrap_or(Self::Value(v)) } } -impl From for PartialValue { - fn from(v: PartialSum) -> Self { +impl From> for PartialValue { + fn from(v: PartialSum) -> Self { Self::PartialSum(v) } } -impl PartialValue { +impl PartialValue { // const BOTTOM: Self = Self::Bottom; // const BOTTOM_REF: &'static Self = &Self::BOTTOM; @@ -220,23 +217,13 @@ impl PartialValue { ps.assert_variants(); } Self::Value(v) => { - assert!(matches!(v.clone().into(), Self::Value(_))) + assert!(v.as_sum().is_none()) } _ => {} } } - pub fn try_into_value(self, typ: &Type) -> Result { - let r = match self { - Self::Value(v) => Ok(v.value().clone()), - Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), - x => Err(x), - }?; - assert_eq!(typ, &r.get_type()); - Ok(r) - } - - fn join_mut_value_handle(&mut self, vh: ValueHandle) -> bool { + fn join_mut_value_handle(&mut self, vh: V) -> bool { self.assert_invariants(); match &*self { Self::Top => return false, @@ -257,7 +244,7 @@ impl PartialValue { true } - fn meet_mut_value_handle(&mut self, vh: ValueHandle) -> bool { + fn meet_mut_value_handle(&mut self, vh: V) -> bool { self.assert_invariants(); match &*self { Self::Bottom => false, @@ -412,7 +399,7 @@ impl PartialValue { Self::variant(0, []) } - pub fn variant_values(&self, tag: usize, len: usize) -> Option> { + pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match self { PartialValue::Bottom => return None, PartialValue::Value(v) => v @@ -438,7 +425,19 @@ impl PartialValue { } } -impl PartialOrd for PartialValue { +impl> PartialValue { + pub fn try_into_value(self, typ: &Type) -> Result { + let r = match self { + Self::Value(v) => Ok(v.into().clone()), + Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), + x => Err(x), + }?; + assert_eq!(typ, &r.get_type()); + Ok(r) + } +} + +impl PartialOrd for PartialValue { fn partial_cmp(&self, other: &Self) -> Option { use std::cmp::Ordering; match (self, other) { diff --git a/hugr-passes/src/const_fold2/partial_value/test.rs b/hugr-passes/src/const_fold2/partial_value/test.rs index 6621f0a69..33c8f3c8d 100644 --- a/hugr-passes/src/const_fold2/partial_value/test.rs +++ b/hugr-passes/src/const_fold2/partial_value/test.rs @@ -8,7 +8,9 @@ use hugr_core::{ types::{Type, TypeArg, TypeEnum}, }; -use super::{PartialSum, PartialValue, ValueHandle, ValueKey}; +use super::{PartialSum, PartialValue}; +use crate::const_fold2::value_handle::{ValueHandle, ValueKey}; + impl Arbitrary for ValueHandle { type Parameters = (); type Strategy = BoxedStrategy; @@ -48,7 +50,7 @@ impl TestSumLeafType { } } - fn type_check(&self, ps: &PartialSum) -> bool { + fn type_check(&self, ps: &PartialSum) -> bool { match self { Self::Int(_) => false, Self::Unit => { @@ -61,7 +63,7 @@ impl TestSumLeafType { } } - fn partial_value_strategy(self) -> impl Strategy { + fn partial_value_strategy(self) -> impl Strategy> { match self { Self::Int(t) => { let TypeEnum::Extension(ct) = t.as_type_enum() else { @@ -165,7 +167,7 @@ impl TestSumType { } } - fn type_check(&self, pv: &PartialValue) -> bool { + fn type_check(&self, pv: &PartialValue) -> bool { match (self, pv) { (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, (_, PartialValue::Value(v)) => self.get_type() == v.get_type(), @@ -253,7 +255,7 @@ proptest! { } } -fn any_partial_value_of_type(ust: TestSumType) -> impl Strategy { +fn any_partial_value_of_type(ust: TestSumType) -> impl Strategy> { ust.select().prop_flat_map(|x| match x { Either::Left(l) => l.partial_value_strategy().boxed(), Either::Right((index, usts)) => { @@ -273,15 +275,15 @@ fn any_partial_value_of_type(ust: TestSumType) -> impl Strategy::Parameters, -) -> impl Strategy { +) -> impl Strategy> { any_with::(params).prop_flat_map(any_partial_value_of_type) } -fn any_partial_value() -> impl Strategy { +fn any_partial_value() -> impl Strategy> { any_partial_value_with(Default::default()) } -fn any_partial_values() -> impl Strategy { +fn any_partial_values() -> impl Strategy; N]> { any::().prop_flat_map(|ust| { TryInto::<[_; N]>::try_into( (0..N) @@ -292,7 +294,7 @@ fn any_partial_values() -> impl Strategy impl Strategy { +fn any_typed_partial_value() -> impl Strategy)> { any::() .prop_flat_map(|t| any_partial_value_of_type(t.clone()).prop_map(move |v| (t.clone(), v))) } diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs similarity index 95% rename from hugr-passes/src/const_fold2/partial_value/value_handle.rs rename to hugr-passes/src/const_fold2/value_handle.rs index 147048ae7..b5af487a8 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -2,11 +2,12 @@ use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; use hugr_core::ops::constant::{CustomConst, Sum}; - use hugr_core::ops::Value; use hugr_core::types::Type; use hugr_core::Node; +use super::partial_value::{AbstractValue, PartialSum, PartialValue}; + #[derive(Clone, Debug)] pub struct HashedConst { hash: u64, @@ -78,7 +79,13 @@ impl ValueHandle { self.1.as_ref() } - pub fn as_sum(&self) -> Option<(usize, impl Iterator + '_)> { + pub fn get_type(&self) -> Type { + self.1.get_type() + } +} + +impl AbstractValue for ValueHandle { + fn as_sum(&self) -> Option<(usize, impl Iterator + '_)> { match self.value() { Value::Sum(Sum { tag, values, .. }) => Some(( *tag, @@ -90,10 +97,6 @@ impl ValueHandle { _ => None, } } - - pub fn get_type(&self) -> Type { - self.1.get_type() - } } impl PartialEq for ValueHandle { @@ -119,6 +122,12 @@ impl Hash for ValueHandle { } } +impl From for Value { + fn from(value: ValueHandle) -> Self { + (*value.1).clone() + } +} + #[cfg(test)] mod test { use hugr_core::{ From af8827b42b25cb4a1f5158844cb0ebbb40c4d49c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 12:16:41 +0100 Subject: [PATCH 021/281] Move partial_value.rs inside datalog/ --- hugr-passes/src/const_fold2.rs | 1 - hugr-passes/src/const_fold2/datalog.rs | 11 +++++++++-- .../src/const_fold2/{ => datalog}/partial_value.rs | 0 .../const_fold2/{ => datalog}/partial_value/test.rs | 0 hugr-passes/src/const_fold2/datalog/test.rs | 3 ++- hugr-passes/src/const_fold2/datalog/utils.rs | 8 +------- hugr-passes/src/const_fold2/value_handle.rs | 2 +- 7 files changed, 13 insertions(+), 12 deletions(-) rename hugr-passes/src/const_fold2/{ => datalog}/partial_value.rs (100%) rename hugr-passes/src/const_fold2/{ => datalog}/partial_value/test.rs (100%) diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index 13af5c709..7d6725fb1 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -1,3 +1,2 @@ pub mod datalog; -mod partial_value; pub mod value_handle; diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index c06b2a285..fbe008d43 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -6,17 +6,24 @@ use hugr_core::ops::Value; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; mod context; +mod partial_value; mod utils; -pub use utils::{TailLoopTermination, ValueRow, IO, PV}; +use utils::{TailLoopTermination, ValueRow, PV}; -use super::partial_value::AbstractValue; +pub use partial_value::{AbstractValue, PartialSum, PartialValue}; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { fn hugr(&self) -> &impl HugrView; fn value_from_load_constant(&self, node: Node) -> V; } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum IO { + Input, + Output, +} + ascent::ascent! { // The trait-indirection layer here means we can just write 'C' but in practice ATM // DataflowContext (for H: HugrView) would be sufficient, there's really no diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/datalog/partial_value.rs similarity index 100% rename from hugr-passes/src/const_fold2/partial_value.rs rename to hugr-passes/src/const_fold2/datalog/partial_value.rs diff --git a/hugr-passes/src/const_fold2/partial_value/test.rs b/hugr-passes/src/const_fold2/datalog/partial_value/test.rs similarity index 100% rename from hugr-passes/src/const_fold2/partial_value/test.rs rename to hugr-passes/src/const_fold2/datalog/partial_value/test.rs diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index f80cc903e..7e7057ffa 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -7,7 +7,8 @@ use hugr_core::{ types::{Signature, SumType, Type, TypeRow}, }; -use super::super::partial_value::PartialValue; +use super::partial_value::PartialValue; + use super::*; #[test] diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 42138396e..16b942e1d 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -8,7 +8,7 @@ use std::{cmp::Ordering, ops::Index}; use ascent::lattice::{BoundedLattice, Lattice}; use itertools::zip_eq; -use crate::const_fold2::partial_value::{AbstractValue, PartialValue}; +use crate::const_fold2::datalog::{AbstractValue, PartialValue}; use hugr_core::{ ops::{OpTrait as _, Value}, types::{Signature, Type, TypeRow}, @@ -229,12 +229,6 @@ pub(super) fn partial_value_tuple_from_value_row(r: ValueRow impl Iterator + '_ { h.in_value_types(n).map(|x| x.0) } diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index b5af487a8..b586f4cab 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -6,7 +6,7 @@ use hugr_core::ops::Value; use hugr_core::types::Type; use hugr_core::Node; -use super::partial_value::{AbstractValue, PartialSum, PartialValue}; +use super::datalog::{AbstractValue, PartialSum, PartialValue}; #[derive(Clone, Debug)] pub struct HashedConst { From 4b614365bd1e2bf686eec1e432c66bd4a766ede2 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 15:26:50 +0100 Subject: [PATCH 022/281] Hide PartialSum/PartialValue --- hugr-passes/src/const_fold2/datalog.rs | 2 +- hugr-passes/src/const_fold2/datalog/utils.rs | 2 +- hugr-passes/src/const_fold2/value_handle.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index fbe008d43..4ffbca32a 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -11,7 +11,7 @@ mod utils; use utils::{TailLoopTermination, ValueRow, PV}; -pub use partial_value::{AbstractValue, PartialSum, PartialValue}; +pub use partial_value::AbstractValue; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { fn hugr(&self) -> &impl HugrView; diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 16b942e1d..881963666 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -8,7 +8,7 @@ use std::{cmp::Ordering, ops::Index}; use ascent::lattice::{BoundedLattice, Lattice}; use itertools::zip_eq; -use crate::const_fold2::datalog::{AbstractValue, PartialValue}; +use super::{partial_value::PartialValue, AbstractValue}; use hugr_core::{ ops::{OpTrait as _, Value}, types::{Signature, Type, TypeRow}, diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index b586f4cab..2bc16994a 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -6,7 +6,7 @@ use hugr_core::ops::Value; use hugr_core::types::Type; use hugr_core::Node; -use super::datalog::{AbstractValue, PartialSum, PartialValue}; +use super::datalog::AbstractValue; #[derive(Clone, Debug)] pub struct HashedConst { From 8b31d8c8abb847533cf989d8dc82927a18fd7d78 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 16:08:52 +0100 Subject: [PATCH 023/281] refactor: ValueRow::single_among_bottoms --- hugr-passes/src/const_fold2/datalog/utils.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 881963666..91fb1723c 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -105,17 +105,13 @@ impl ValueRow { Self(vec![PV::bottom(); len]) } - fn singleton(len: usize, idx: usize, v: PV) -> Self { + fn single_among_bottoms(len: usize, idx: usize, v: PV) -> Self { assert!(idx < len); let mut r = Self::new(len); r.0[idx] = v; r } - fn singleton_from_row(r: &TypeRow, idx: usize, v: PV) -> Self { - Self::singleton(r.len(), idx, v) - } - fn bottom_from_row(r: &TypeRow) -> Self { Self::new(r.len()) } @@ -222,7 +218,7 @@ pub(super) fn singleton_in_row( h.get_optype(*n).description() ); } - ValueRow::singleton_from_row(&h.signature(*n).unwrap().input, ip.index(), v) + ValueRow::single_among_bottoms(h.signature(*n).unwrap().input.len(), ip.index(), v) } pub(super) fn partial_value_tuple_from_value_row(r: ValueRow) -> PV { From 6a729613706ee8394027631409f59e57b6cee841 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 16:11:41 +0100 Subject: [PATCH 024/281] Factor out propagate_leaf_op; add ValueRow::from_iter --- hugr-passes/src/const_fold2/datalog.rs | 58 +++++++++----------- hugr-passes/src/const_fold2/datalog/utils.rs | 6 ++ 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 4ffbca32a..a636293d6 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -2,7 +2,7 @@ use ascent::lattice::BoundedLattice; use std::collections::HashMap; use std::hash::Hash; -use hugr_core::ops::Value; +use hugr_core::ops::{OpType, Value}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; mod context; @@ -68,37 +68,12 @@ ascent::ascent! { node_in_value_row(c, n, utils::bottom_row(c.hugr(), *n)) <-- node(c, n); node_in_value_row(c, n, utils::singleton_in_row(c.hugr(), n, p, v.clone())) <-- in_wire_value(c, n, p, v); - - // Per node-type rules - // TODO do all leaf ops with a rule - // define `fn propagate_leaf_op(Context, Node, ValueRow) -> ValueRow - - // LoadConstant - relation load_constant_node(C, Node); - load_constant_node(c, n) <-- node(c, n), if c.get_optype(*n).is_load_constant(); - - out_wire_value(c, n, 0.into(), PV::from(c.value_from_load_constant(*n))) <-- - load_constant_node(c, n); - - - // MakeTuple - relation make_tuple_node(C, Node); - make_tuple_node(c, n) <-- node(c, n), if c.get_optype(*n).is_make_tuple(); - - out_wire_value(c, n, 0.into(), utils::partial_value_tuple_from_value_row(vs.clone())) <-- - make_tuple_node(c, n), node_in_value_row(c, n, vs); - - - // UnpackTuple - relation unpack_tuple_node(C, Node); - unpack_tuple_node(c,n) <-- node(c, n), if c.get_optype(*n).is_unpack_tuple(); - out_wire_value(c, n, p, v) <-- - unpack_tuple_node(c, n), - in_wire_value(c, n, IncomingPort::from(0), tup), - if let Some(fields) = tup.variant_values(0, utils::value_outputs(c.hugr(),*n).count()), - for (p,v) in (0..).map(OutgoingPort::from).zip(fields); - + node(c, n), + if !c.get_optype(*n).is_container(), + node_in_value_row(c, n, vs), + if let Some(outs) = propagate_leaf_op(c, *n, vs.clone()), + for (p,v) in (0..).map(OutgoingPort::from).zip(outs); // DFG relation dfg_node(C, Node); @@ -177,6 +152,27 @@ ascent::ascent! { } +fn propagate_leaf_op( + c: &impl DFContext, + n: Node, + ins: ValueRow, +) -> Option> { + match c.get_optype(n) { + OpType::LoadConstant(_) => Some(ValueRow::from_iter([PV::from( + c.value_from_load_constant(n), + )])), // ins empty + OpType::MakeTuple(_) => Some(ValueRow::from_iter([ + utils::partial_value_tuple_from_value_row(ins), + ])), + OpType::UnpackTuple(_) => { + let [tup] = ins.into_iter().collect::>().try_into().unwrap(); + tup.variant_values(0, utils::value_outputs(c.hugr(), n).count()) + .map(ValueRow::from_iter) + } + _ => None, + } +} + // TODO This should probably be called 'Analyser' or something pub struct Machine>( AscentProgram, diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 91fb1723c..2719f19ce 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -135,6 +135,12 @@ impl ValueRow { // } } +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + impl PartialOrd for ValueRow { fn partial_cmp(&self, other: &Self) -> Option { self.0.partial_cmp(&other.0) From 780af9b5b838141707d1eb457b4b870b5b08944a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 16:19:44 +0100 Subject: [PATCH 025/281] Add handling for Tag --- hugr-passes/src/const_fold2/datalog.rs | 5 ++--- hugr-passes/src/const_fold2/datalog/utils.rs | 8 ++++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index a636293d6..2802e37a5 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -161,14 +161,13 @@ fn propagate_leaf_op( OpType::LoadConstant(_) => Some(ValueRow::from_iter([PV::from( c.value_from_load_constant(n), )])), // ins empty - OpType::MakeTuple(_) => Some(ValueRow::from_iter([ - utils::partial_value_tuple_from_value_row(ins), - ])), + OpType::MakeTuple(_) => Some(ValueRow::from_iter([PV::variant(0, ins)])), OpType::UnpackTuple(_) => { let [tup] = ins.into_iter().collect::>().try_into().unwrap(); tup.variant_values(0, utils::value_outputs(c.hugr(), n).count()) .map(ValueRow::from_iter) } + OpType::Tag(t) => Some(ValueRow::from_iter([PV::variant(t.tag, ins)])), _ => None, } } diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 2719f19ce..132172db0 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -36,6 +36,10 @@ impl From> for PV { } impl PV { + pub fn variant(tag: usize, r: impl IntoIterator>) -> Self { + PartialValue::variant(tag, r.into_iter().map(|x| x.0)).into() + } + pub fn variant_values(&self, variant: usize, len: usize) -> Option>> { Some( self.0 @@ -227,10 +231,6 @@ pub(super) fn singleton_in_row( ValueRow::single_among_bottoms(h.signature(*n).unwrap().input.len(), ip.index(), v) } -pub(super) fn partial_value_tuple_from_value_row(r: ValueRow) -> PV { - PartialValue::variant(0, r.into_iter().map(|x| x.0)).into() -} - pub(super) fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { h.in_value_types(n).map(|x| x.0) } From cabcf04ac018574fe06ea8be9b109234893589bf Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 17:25:55 +0100 Subject: [PATCH 026/281] Remove PV (use typedef in datalog.rs) --- hugr-passes/src/const_fold2/datalog.rs | 13 ++- hugr-passes/src/const_fold2/datalog/utils.rs | 101 +++++-------------- 2 files changed, 29 insertions(+), 85 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 2802e37a5..9de140cb1 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -9,9 +9,10 @@ mod context; mod partial_value; mod utils; -use utils::{TailLoopTermination, ValueRow, PV}; +use utils::{TailLoopTermination, ValueRow}; pub use partial_value::AbstractValue; +type PV = partial_value::PartialValue; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { fn hugr(&self) -> &impl HugrView; @@ -190,11 +191,9 @@ impl> Machine { pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator)>) { assert!(self.1.is_none()); - self.0.out_wire_value_proto.extend( - wires - .into_iter() - .map(|(w, v)| (w.node(), w.source(), v.into())), - ); + self.0 + .out_wire_value_proto + .extend(wires.into_iter().map(|(w, v)| (w.node(), w.source(), v))); } pub fn run(&mut self, context: C) { @@ -205,7 +204,7 @@ impl> Machine { self.0 .out_wire_value .iter() - .map(|(_, n, p, v)| (Wire::new(*n, *p), v.clone().into())) + .map(|(_, n, p, v)| (Wire::new(*n, *p), v.clone())) .collect(), ) } diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 132172db0..8fbb40c02 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -10,106 +10,51 @@ use itertools::zip_eq; use super::{partial_value::PartialValue, AbstractValue}; use hugr_core::{ - ops::{OpTrait as _, Value}, - types::{Signature, Type, TypeRow}, + ops::OpTrait as _, + types::{Signature, TypeRow}, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, }; #[cfg(test)] use proptest_derive::Arbitrary; -#[derive(PartialEq, Eq, Hash, Clone, Debug)] -pub struct PV(PartialValue); - -// Implement manually as PartialValue is PartialOrd even when V isn't -// (deriving PartialOrd conditions on V: PartialOrd, which is not necessary) -impl PartialOrd for PV { - fn partial_cmp(&self, other: &Self) -> Option { - self.0.partial_cmp(&other.0) - } -} - -impl From> for PV { - fn from(inner: PartialValue) -> Self { - Self(inner) - } -} - -impl PV { - pub fn variant(tag: usize, r: impl IntoIterator>) -> Self { - PartialValue::variant(tag, r.into_iter().map(|x| x.0)).into() - } - - pub fn variant_values(&self, variant: usize, len: usize) -> Option>> { - Some( - self.0 - .variant_values(variant, len)? - .into_iter() - .map(PV::from) - .collect(), - ) - } - - pub fn supports_tag(&self, tag: usize) -> bool { - self.0.supports_tag(tag) - } -} - -impl> PV { - pub fn try_into_value(self, ty: &Type) -> Result { - self.0.try_into_value(ty).map_err(Self) - } -} - -impl From> for PartialValue { - fn from(value: PV) -> Self { - value.0 - } -} - -impl From for PV { - fn from(inner: V) -> Self { - Self(inner.into()) - } -} - -impl Lattice for PV { +impl Lattice for PartialValue { fn meet(self, other: Self) -> Self { - self.0.meet(other.0).into() + self.meet(other) } fn meet_mut(&mut self, other: Self) -> bool { - self.0.meet_mut(other.0) + self.meet_mut(other) } fn join(self, other: Self) -> Self { - self.0.join(other.0).into() + self.join(other) } fn join_mut(&mut self, other: Self) -> bool { - self.0.join_mut(other.0) + self.join_mut(other) } } -impl BoundedLattice for PV { +impl BoundedLattice for PartialValue { fn bottom() -> Self { - PartialValue::bottom().into() + Self::bottom() } fn top() -> Self { - PartialValue::top().into() + Self::top() } } #[derive(PartialEq, Clone, Eq, Hash)] -pub struct ValueRow(Vec>); +pub struct ValueRow(Vec>); impl ValueRow { fn new(len: usize) -> Self { - Self(vec![PV::bottom(); len]) + Self(vec![PartialValue::bottom(); len]) } - fn single_among_bottoms(len: usize, idx: usize, v: PV) -> Self { + fn single_among_bottoms(len: usize, idx: usize, v: PartialValue) -> Self { assert!(idx < len); let mut r = Self::new(len); r.0[idx] = v; @@ -120,7 +65,7 @@ impl ValueRow { Self::new(r.len()) } - pub fn iter(&self) -> impl Iterator> { + pub fn iter(&self) -> impl Iterator> { self.0.iter() } @@ -128,7 +73,7 @@ impl ValueRow { &self, variant: usize, len: usize, - ) -> Option> + '_> { + ) -> Option> + '_> { self[0] .variant_values(variant, len) .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) @@ -139,8 +84,8 @@ impl ValueRow { // } } -impl FromIterator> for ValueRow { - fn from_iter>>(iter: T) -> Self { +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { Self(iter.into_iter().collect()) } } @@ -182,9 +127,9 @@ impl Lattice for ValueRow { } impl IntoIterator for ValueRow { - type Item = PV; + type Item = PartialValue; - type IntoIter = > as IntoIterator>::IntoIter; + type IntoIter = > as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() @@ -193,9 +138,9 @@ impl IntoIterator for ValueRow { impl Index for ValueRow where - Vec>: Index, + Vec>: Index, { - type Output = > as Index>::Output; + type Output = > as Index>::Output; fn index(&self, index: Idx) -> &Self::Output { self.0.index(index) @@ -215,7 +160,7 @@ pub(super) fn singleton_in_row( h: &impl HugrView, n: &Node, ip: &IncomingPort, - v: PV, + v: PartialValue, ) -> ValueRow { let Some(sig) = h.signature(*n) else { panic!("dougrulz"); @@ -248,7 +193,7 @@ pub enum TailLoopTermination { } impl TailLoopTermination { - pub fn from_control_value(v: &PV) -> Self { + pub fn from_control_value(v: &PartialValue) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); if may_break && !may_continue { Self::ExactlyZeroContinues From 5e4a04fe1cbbd8487d2a2dd7d53babb3fd2887ae Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 18:27:32 +0100 Subject: [PATCH 027/281] Allow DFContext to interpret any leaf op (except MakeTuple/etc.); pub PV+PS --- hugr-passes/src/const_fold2/datalog.rs | 30 ++++++++++----- .../src/const_fold2/datalog/context.rs | 38 +++++++++++++------ 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 9de140cb1..e8a4ce621 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -11,12 +11,16 @@ mod utils; use utils::{TailLoopTermination, ValueRow}; -pub use partial_value::AbstractValue; +pub use partial_value::{AbstractValue, PartialSum, PartialValue}; type PV = partial_value::PartialValue; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { fn hugr(&self) -> &impl HugrView; - fn value_from_load_constant(&self, node: Node) -> V; + fn interpret_leaf_op( + &self, + node: Node, + ins: &[PartialValue], + ) -> Option>>; } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -73,7 +77,7 @@ ascent::ascent! { node(c, n), if !c.get_optype(*n).is_container(), node_in_value_row(c, n, vs), - if let Some(outs) = propagate_leaf_op(c, *n, vs.clone()), + if let Some(outs) = propagate_leaf_op(c, *n, &vs[..]), for (p,v) in (0..).map(OutgoingPort::from).zip(outs); // DFG @@ -156,20 +160,26 @@ ascent::ascent! { fn propagate_leaf_op( c: &impl DFContext, n: Node, - ins: ValueRow, + ins: &[PV], ) -> Option> { match c.get_optype(n) { - OpType::LoadConstant(_) => Some(ValueRow::from_iter([PV::from( - c.value_from_load_constant(n), - )])), // ins empty - OpType::MakeTuple(_) => Some(ValueRow::from_iter([PV::variant(0, ins)])), + // Handle basics here. I guess we could allow DFContext to specify but at the least + // we'd want these ones to be easily available for reuse. + OpType::MakeTuple(_) => Some(ValueRow::from_iter([PV::variant( + 0, + ins.into_iter().cloned(), + )])), OpType::UnpackTuple(_) => { let [tup] = ins.into_iter().collect::>().try_into().unwrap(); tup.variant_values(0, utils::value_outputs(c.hugr(), n).count()) .map(ValueRow::from_iter) } - OpType::Tag(t) => Some(ValueRow::from_iter([PV::variant(t.tag, ins)])), - _ => None, + OpType::Tag(t) => Some(ValueRow::from_iter([PV::variant( + t.tag, + ins.into_iter().cloned(), + )])), + OpType::Input(_) | OpType::Output(_) => None, // handled by parent + _ => c.interpret_leaf_op(n, ins).map(ValueRow::from_iter), } } diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index 31a3233fd..f4d6b7ab8 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -2,12 +2,13 @@ use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; -use hugr_core::{Hugr, HugrView}; +use hugr_core::ops::OpType; +use hugr_core::{Hugr, HugrView, Node}; +// ALAN Note this probably belongs with ValueHandle, outside datalog +use super::{DFContext, PartialValue}; use crate::const_fold2::value_handle::ValueHandle; -use super::DFContext; - #[derive(Debug)] pub(super) struct DataflowContext(Arc); @@ -60,14 +61,27 @@ impl DFContext for DataflowContext { self.0.as_ref() } - fn value_from_load_constant(&self, node: hugr_core::Node) -> ValueHandle { - let load_op = self.0.get_optype(node).as_load_constant().unwrap(); - let const_node = self - .0 - .single_linked_output(node, load_op.constant_port()) - .unwrap() - .0; - let const_op = self.0.get_optype(const_node).as_const().unwrap(); - ValueHandle::new(const_node.into(), Arc::new(const_op.value().clone())) + fn interpret_leaf_op( + &self, + n: Node, + ins: &[PartialValue], + ) -> Option>> { + match self.0.get_optype(n) { + OpType::LoadConstant(load_op) => { + // ins empty as static edge, we need to find the constant ourselves + let const_node = self + .0 + .single_linked_output(n, load_op.constant_port()) + .unwrap() + .0; + let const_op = self.0.get_optype(const_node).as_const().unwrap(); + Some(vec![ValueHandle::new( + const_node.into(), + Arc::new(const_op.value().clone()), + ) + .into()]) + } + _ => None, + } } } From 145388653efe29df2ac959b1c2813ac39b98146f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 18:31:18 +0100 Subject: [PATCH 028/281] Also fold extension ops --- .../src/const_fold2/datalog/context.rs | 28 ++++++++++++++++--- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index f4d6b7ab8..7985e5410 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -2,12 +2,11 @@ use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; -use hugr_core::ops::OpType; -use hugr_core::{Hugr, HugrView, Node}; +use hugr_core::ops::{CustomOp, DataflowOpTrait, OpType}; +use hugr_core::{Hugr, HugrView, IncomingPort, Node, PortIndex}; -// ALAN Note this probably belongs with ValueHandle, outside datalog +use crate::const_fold2::value_handle::{ValueHandle, ValueKey}; use super::{DFContext, PartialValue}; -use crate::const_fold2::value_handle::ValueHandle; #[derive(Debug)] pub(super) struct DataflowContext(Arc); @@ -81,6 +80,27 @@ impl DFContext for DataflowContext { ) .into()]) } + OpType::CustomOp(CustomOp::Extension(op)) => { + let sig = op.signature(); + let known_ins = sig + .input_types() + .into_iter() + .enumerate() + .zip(ins.iter()) + .filter_map(|((i, ty), pv)| { + pv.clone() + .try_into_value(ty) + .map(|v| (IncomingPort::from(i), v)) + .ok() + }) + .collect::>(); + let outs = op.constant_fold(&known_ins)?; + let mut res = vec![PartialValue::bottom(); sig.output_count()]; + for (op, v) in outs { + res[op.index()] = ValueHandle::new(ValueKey::Node(n), Arc::new(v)).into() + } + Some(res) + } _ => None, } } From 221e96cc0f6fe571563b19b2ffb4fa0643045cc1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 30 Aug 2024 14:13:08 +0100 Subject: [PATCH 029/281] Comment as_sum --- hugr-passes/src/const_fold2/datalog/partial_value.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/hugr-passes/src/const_fold2/datalog/partial_value.rs b/hugr-passes/src/const_fold2/datalog/partial_value.rs index 6a2a614a8..8441027d4 100644 --- a/hugr-passes/src/const_fold2/datalog/partial_value.rs +++ b/hugr-passes/src/const_fold2/datalog/partial_value.rs @@ -7,7 +7,11 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; +/// Aka, deconstructible into Sum (TryIntoSum ?) pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { + /// We write this way to optimize query/inspection (is-it-a-sum), + /// at the cost of requiring more cloning during actual conversion + /// (inside the lazy Iterator, or for the error case, as Self remains) fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; } From cd4e15c47467ab9a8894e705e140f381ae94478f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 30 Aug 2024 10:44:19 +0100 Subject: [PATCH 030/281] Rename DataflowContext to HugrValueContext --- hugr-passes/src/const_fold2/datalog.rs | 4 --- .../src/const_fold2/datalog/context.rs | 25 +++++++++++-------- hugr-passes/src/const_fold2/datalog/test.rs | 16 ++++++------ 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index e8a4ce621..38a6dadbc 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -30,10 +30,6 @@ pub enum IO { } ascent::ascent! { - // The trait-indirection layer here means we can just write 'C' but in practice ATM - // DataflowContext (for H: HugrView) would be sufficient, there's really no - // point in using anything else yet. However DFContext will be useful when we - // move interpretation of nodes out into a trait method. struct AscentProgram>; relation context(C); relation out_wire_value_proto(Node, OutgoingPort, PV); diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index 7985e5410..a17b69ab2 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -8,10 +8,13 @@ use hugr_core::{Hugr, HugrView, IncomingPort, Node, PortIndex}; use crate::const_fold2::value_handle::{ValueHandle, ValueKey}; use super::{DFContext, PartialValue}; +/// An implementation of [DFContext] with [ValueHandle] +/// that just stores a Hugr (actually any [HugrView]), +/// (there is )no state for operation-interpretation). #[derive(Debug)] -pub(super) struct DataflowContext(Arc); +pub struct HugrValueContext(Arc); -impl DataflowContext { +impl HugrValueContext { pub fn new(hugr: H) -> Self { Self(Arc::new(hugr)) } @@ -19,35 +22,35 @@ impl DataflowContext { // Deriving Clone requires H:HugrView to implement Clone, // but we don't need that as we only clone the Arc. -impl Clone for DataflowContext { +impl Clone for HugrValueContext { fn clone(&self) -> Self { Self(self.0.clone()) } } -impl Hash for DataflowContext { +impl Hash for HugrValueContext { fn hash(&self, _state: &mut I) {} } -impl PartialEq for DataflowContext { +impl PartialEq for HugrValueContext { fn eq(&self, other: &Self) -> bool { - // Any AscentProgram should have only one DataflowContext (maybe cloned) + // Any AscentProgram should have only one DFContext (maybe cloned) assert!(Arc::ptr_eq(&self.0, &other.0)); true } } -impl Eq for DataflowContext {} +impl Eq for HugrValueContext {} -impl PartialOrd for DataflowContext { +impl PartialOrd for HugrValueContext { fn partial_cmp(&self, other: &Self) -> Option { - // Any AscentProgram should have only one DataflowContext (maybe cloned) + // Any AscentProgram should have only one DFContext (maybe cloned) assert!(Arc::ptr_eq(&self.0, &other.0)); Some(std::cmp::Ordering::Equal) } } -impl Deref for DataflowContext { +impl Deref for HugrValueContext { type Target = Hugr; fn deref(&self) -> &Self::Target { @@ -55,7 +58,7 @@ impl Deref for DataflowContext { } } -impl DFContext for DataflowContext { +impl DFContext for HugrValueContext { fn hugr(&self) -> &impl HugrView { self.0.as_ref() } diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index 7e7057ffa..f24ba67c5 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -1,4 +1,4 @@ -use context::DataflowContext; +use context::HugrValueContext; use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{prelude::BOOL_T, ExtensionSet, EMPTY_REG}, @@ -20,7 +20,7 @@ fn test_make_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run(DataflowContext::new(&hugr)); + machine.run(HugrValueContext::new(&hugr)); let x = machine.read_out_wire_value(&hugr, v3).unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); @@ -39,7 +39,7 @@ fn test_unpack_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run(DataflowContext::new(&hugr)); + machine.run(HugrValueContext::new(&hugr)); let o1_r = machine.read_out_wire_value(&hugr, o1).unwrap(); assert_eq!(o1_r, Value::false_val()); @@ -58,7 +58,7 @@ fn test_unpack_const() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run(DataflowContext::new(&hugr)); + machine.run(HugrValueContext::new(&hugr)); let o_r = machine.read_out_wire_value(&hugr, o).unwrap(); assert_eq!(o_r, Value::true_val()); @@ -84,7 +84,7 @@ fn test_tail_loop_never_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run(DataflowContext::new(&hugr)); + machine.run(HugrValueContext::new(&hugr)); // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); @@ -116,7 +116,7 @@ fn test_tail_loop_always_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run(DataflowContext::new(&hugr)); + machine.run(HugrValueContext::new(&hugr)); let o_r1 = machine.read_out_wire_partial_value(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom().into()); @@ -166,7 +166,7 @@ fn test_tail_loop_iterates_twice() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::new(); - machine.run(DataflowContext::new(&hugr)); + machine.run(HugrValueContext::new(&hugr)); // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); @@ -222,7 +222,7 @@ fn conditional() { let arg_pv = PartialValue::variant(1, []).join(PartialValue::variant(2, [PartialValue::variant(0, [])])); machine.propolutate_out_wires([(arg_w, arg_pv.into())]); - machine.run(DataflowContext::new(&hugr)); + machine.run(HugrValueContext::new(&hugr)); let cond_r1 = machine.read_out_wire_value(&hugr, cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); From c5ab2a94def4cf80709ac85b01b838c8bfa46ae8 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 30 Aug 2024 10:52:10 +0100 Subject: [PATCH 031/281] Move {datalog=>value_handle}/context.rs - an impl, datalog uses only DFContext --- hugr-passes/src/const_fold2/datalog.rs | 1 - hugr-passes/src/const_fold2/datalog/test.rs | 3 ++- hugr-passes/src/const_fold2/value_handle.rs | 3 +++ .../src/const_fold2/{datalog => value_handle}/context.rs | 4 ++-- 4 files changed, 7 insertions(+), 4 deletions(-) rename hugr-passes/src/const_fold2/{datalog => value_handle}/context.rs (97%) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 38a6dadbc..69156a9d9 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -5,7 +5,6 @@ use std::hash::Hash; use hugr_core::ops::{OpType, Value}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; -mod context; mod partial_value; mod utils; diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index f24ba67c5..1e3cdcc98 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -1,4 +1,5 @@ -use context::HugrValueContext; +use crate::const_fold2::value_handle::HugrValueContext; + use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{prelude::BOOL_T, ExtensionSet, EMPTY_REG}, diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index 2bc16994a..daf8a98fd 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -8,6 +8,9 @@ use hugr_core::Node; use super::datalog::AbstractValue; +mod context; +pub use context::HugrValueContext; + #[derive(Clone, Debug)] pub struct HashedConst { hash: u64, diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/value_handle/context.rs similarity index 97% rename from hugr-passes/src/const_fold2/datalog/context.rs rename to hugr-passes/src/const_fold2/value_handle/context.rs index a17b69ab2..06ced3238 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/value_handle/context.rs @@ -5,8 +5,8 @@ use std::sync::Arc; use hugr_core::ops::{CustomOp, DataflowOpTrait, OpType}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, PortIndex}; -use crate::const_fold2::value_handle::{ValueHandle, ValueKey}; -use super::{DFContext, PartialValue}; +use super::{ValueHandle, ValueKey}; +use crate::const_fold2::datalog::{DFContext, PartialValue}; /// An implementation of [DFContext] with [ValueHandle] /// that just stores a Hugr (actually any [HugrView]), From 6c80acf500491f16722767e61d80c4803e2b8bbe Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 30 Aug 2024 10:57:31 +0100 Subject: [PATCH 032/281] Comment re. propagate_leaf_op --- hugr-passes/src/const_fold2/datalog.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 69156a9d9..c924f36e0 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -158,8 +158,9 @@ fn propagate_leaf_op( ins: &[PV], ) -> Option> { match c.get_optype(n) { - // Handle basics here. I guess we could allow DFContext to specify but at the least - // we'd want these ones to be easily available for reuse. + // Handle basics here. I guess (given the current interface) we could allow + // DFContext to handle these but at the least we'd want these impls to be + // easily available for reuse. OpType::MakeTuple(_) => Some(ValueRow::from_iter([PV::variant( 0, ins.into_iter().cloned(), @@ -174,6 +175,9 @@ fn propagate_leaf_op( ins.into_iter().cloned(), )])), OpType::Input(_) | OpType::Output(_) => None, // handled by parent + // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, + // thus keeping PartialValue hidden, but AbstractValues + // are not necessarily convertible to Value! _ => c.interpret_leaf_op(n, ins).map(ValueRow::from_iter), } } From 636f14dbece90f238f2cacfe923b6d9f1cab47c5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 30 Aug 2024 14:28:51 +0100 Subject: [PATCH 033/281] Hide PartialValue by abstracting DFContext::InterpretableVal: FromSum (==Value) --- hugr-passes/src/const_fold2/datalog.rs | 39 ++++++++++++++---- .../src/const_fold2/datalog/partial_value.rs | 24 +++++------ hugr-passes/src/const_fold2/datalog/utils.rs | 12 +++++- hugr-passes/src/const_fold2/value_handle.rs | 8 +++- .../src/const_fold2/value_handle/context.rs | 41 ++++++------------- 5 files changed, 71 insertions(+), 53 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index c924f36e0..c6aad5840 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -2,24 +2,25 @@ use ascent::lattice::BoundedLattice; use std::collections::HashMap; use std::hash::Hash; -use hugr_core::ops::{OpType, Value}; -use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; +use hugr_core::ops::{OpTrait, OpType}; +use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire}; mod partial_value; mod utils; use utils::{TailLoopTermination, ValueRow}; -pub use partial_value::{AbstractValue, PartialSum, PartialValue}; +pub use partial_value::{AbstractValue, FromSum}; type PV = partial_value::PartialValue; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { + type InterpretableVal: FromSum + From; fn hugr(&self) -> &impl HugrView; fn interpret_leaf_op( &self, node: Node, - ins: &[PartialValue], - ) -> Option>>; + ins: &[(IncomingPort, Self::InterpretableVal)], + ) -> Vec<(OutgoingPort, V)>; } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -178,7 +179,29 @@ fn propagate_leaf_op( // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, // thus keeping PartialValue hidden, but AbstractValues // are not necessarily convertible to Value! - _ => c.interpret_leaf_op(n, ins).map(ValueRow::from_iter), + op => { + let sig = op.dataflow_signature()?; + let known_ins = sig + .input_types() + .into_iter() + .enumerate() + .zip(ins.iter()) + .filter_map(|((i, ty), pv)| { + pv.clone() + .try_into_value(ty) + .ok() + .map(|v| (IncomingPort::from(i), v)) + }) + .collect::>(); + let known_outs = c.interpret_leaf_op(n, &known_ins); + (!known_outs.is_empty()).then(|| { + let mut res = ValueRow::new(sig.output_count()); + for (p, v) in known_outs { + res[p.index()] = v.into(); + } + res + }) + } } } @@ -241,10 +264,8 @@ impl> Machine { .find_map(|(_, cond2, case2, i)| (&cond == cond2 && &case == case2).then_some(*i)) .unwrap() } -} -impl, C: DFContext> Machine { - pub fn read_out_wire_value(&self, hugr: impl HugrView, w: Wire) -> Option { + pub fn read_out_wire_value(&self, hugr: impl HugrView, w: Wire) -> Option { // dbg!(&w); let pv = self.read_out_wire_partial_value(w)?; // dbg!(&pv); diff --git a/hugr-passes/src/const_fold2/datalog/partial_value.rs b/hugr-passes/src/const_fold2/datalog/partial_value.rs index 8441027d4..b8f9067d4 100644 --- a/hugr-passes/src/const_fold2/datalog/partial_value.rs +++ b/hugr-passes/src/const_fold2/datalog/partial_value.rs @@ -1,7 +1,6 @@ #![allow(missing_docs)] -use hugr_core::ops::Value; -use hugr_core::types::{Type, TypeEnum, TypeRow}; +use hugr_core::types::{SumType, Type, TypeEnum, TypeRow}; use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::collections::HashMap; @@ -15,6 +14,11 @@ pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; } +pub trait FromSum { + fn new_sum(tag: usize, items: impl IntoIterator, st: &SumType) -> Self; + fn debug_check_is_type(&self, _ty: &Type) {} +} + // TODO ALAN inline into PartialValue? Has to be public as it's in a pub enum #[derive(PartialEq, Clone, Eq)] pub struct PartialSum(pub HashMap>>); @@ -94,10 +98,8 @@ impl PartialSum { pub fn supports_tag(&self, tag: usize) -> bool { self.0.contains_key(&tag) } -} -impl> PartialSum { - pub fn try_into_value(self, typ: &Type) -> Result { + pub fn try_into_value>(self, typ: &Type) -> Result { let Ok((k, v)) = self.0.iter().exactly_one() else { Err(self)? }; @@ -118,7 +120,7 @@ impl> PartialSum { .map(|(v, t)| v.clone().try_into_value(t)) .collect::, _>>() { - Ok(vs) => Value::sum(*k, vs, st.clone()).map_err(|_| self), + Ok(vs) => Ok(V2::new_sum(*k, vs, &st)), Err(_) => Err(self), } } @@ -427,16 +429,14 @@ impl PartialValue { PartialValue::Top => true, } } -} -impl> PartialValue { - pub fn try_into_value(self, typ: &Type) -> Result { - let r = match self { - Self::Value(v) => Ok(v.into().clone()), + pub fn try_into_value>(self, typ: &Type) -> Result { + let r: V2 = match self { + Self::Value(v) => Ok(v.clone().into()), Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), x => Err(x), }?; - assert_eq!(typ, &r.get_type()); + r.debug_check_is_type(typ); Ok(r) } } diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 8fbb40c02..c0594fdb7 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -3,7 +3,7 @@ // https://github.com/proptest-rs/proptest/issues/447 #![cfg_attr(test, allow(non_local_definitions))] -use std::{cmp::Ordering, ops::Index}; +use std::{cmp::Ordering, ops::{Index, IndexMut}}; use ascent::lattice::{BoundedLattice, Lattice}; use itertools::zip_eq; @@ -50,7 +50,7 @@ impl BoundedLattice for PartialValue { pub struct ValueRow(Vec>); impl ValueRow { - fn new(len: usize) -> Self { + pub fn new(len: usize) -> Self { Self(vec![PartialValue::bottom(); len]) } @@ -147,6 +147,14 @@ where } } +impl IndexMut for ValueRow +where + Vec>: IndexMut { + fn index_mut(&mut self, index: Idx) -> &mut Self::Output { + self.0.index_mut(index) + } +} + pub(super) fn bottom_row(h: &impl HugrView, n: Node) -> ValueRow { ValueRow::new( h.signature(n) diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index daf8a98fd..2aa1bc2fe 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -6,7 +6,7 @@ use hugr_core::ops::Value; use hugr_core::types::Type; use hugr_core::Node; -use super::datalog::AbstractValue; +use super::datalog::{AbstractValue, FromSum}; mod context; pub use context::HugrValueContext; @@ -102,6 +102,12 @@ impl AbstractValue for ValueHandle { } } +impl FromSum for Value { + fn new_sum(tag: usize, items: impl IntoIterator, st: &hugr_core::types::SumType) -> Self { + Value::Sum(Sum {tag, values: items.into_iter().collect(), sum_type: st.clone()}) + } +} + impl PartialEq for ValueHandle { fn eq(&self, other: &Self) -> bool { // If the keys are equal, we return true since the values must have the diff --git a/hugr-passes/src/const_fold2/value_handle/context.rs b/hugr-passes/src/const_fold2/value_handle/context.rs index 06ced3238..14571bb1b 100644 --- a/hugr-passes/src/const_fold2/value_handle/context.rs +++ b/hugr-passes/src/const_fold2/value_handle/context.rs @@ -2,11 +2,11 @@ use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; -use hugr_core::ops::{CustomOp, DataflowOpTrait, OpType}; -use hugr_core::{Hugr, HugrView, IncomingPort, Node, PortIndex}; +use hugr_core::ops::{CustomOp, OpType, Value}; +use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; use super::{ValueHandle, ValueKey}; -use crate::const_fold2::datalog::{DFContext, PartialValue}; +use crate::const_fold2::datalog::DFContext; /// An implementation of [DFContext] with [ValueHandle] /// that just stores a Hugr (actually any [HugrView]), @@ -59,6 +59,7 @@ impl Deref for HugrValueContext { } impl DFContext for HugrValueContext { + type InterpretableVal = Value; fn hugr(&self) -> &impl HugrView { self.0.as_ref() } @@ -66,45 +67,27 @@ impl DFContext for HugrValueContext { fn interpret_leaf_op( &self, n: Node, - ins: &[PartialValue], - ) -> Option>> { + ins: &[(IncomingPort, Value)], + ) -> Vec<(OutgoingPort,ValueHandle)> { match self.0.get_optype(n) { OpType::LoadConstant(load_op) => { - // ins empty as static edge, we need to find the constant ourselves + assert!(ins.is_empty()); // static edge, so need to find constant let const_node = self .0 .single_linked_output(n, load_op.constant_port()) .unwrap() .0; let const_op = self.0.get_optype(const_node).as_const().unwrap(); - Some(vec![ValueHandle::new( + vec![(OutgoingPort::from(0), ValueHandle::new( const_node.into(), Arc::new(const_op.value().clone()), - ) - .into()]) + ))] } OpType::CustomOp(CustomOp::Extension(op)) => { - let sig = op.signature(); - let known_ins = sig - .input_types() - .into_iter() - .enumerate() - .zip(ins.iter()) - .filter_map(|((i, ty), pv)| { - pv.clone() - .try_into_value(ty) - .map(|v| (IncomingPort::from(i), v)) - .ok() - }) - .collect::>(); - let outs = op.constant_fold(&known_ins)?; - let mut res = vec![PartialValue::bottom(); sig.output_count()]; - for (op, v) in outs { - res[op.index()] = ValueHandle::new(ValueKey::Node(n), Arc::new(v)).into() - } - Some(res) + let ins = ins.into_iter().map(|(p,v)|(*p,v.clone())).collect::>(); + op.constant_fold(&ins).map_or(Vec::new(), |outs|outs.into_iter().map(|(p,v)|(p, ValueHandle::new(ValueKey::Node(n), Arc::new(v)))).collect()) } - _ => None, + _ => vec![], } } } From 0acdcc5767372a290cd58ce6febd95ef152dfb37 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 30 Aug 2024 14:41:15 +0100 Subject: [PATCH 034/281] Use Value::sum, adding FromSum::Err; fmt --- .../src/const_fold2/datalog/partial_value.rs | 11 +++++++--- hugr-passes/src/const_fold2/datalog/utils.rs | 10 ++++++--- hugr-passes/src/const_fold2/value_handle.rs | 11 +++++++--- .../src/const_fold2/value_handle/context.rs | 21 ++++++++++++------- 4 files changed, 37 insertions(+), 16 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog/partial_value.rs b/hugr-passes/src/const_fold2/datalog/partial_value.rs index b8f9067d4..73e962287 100644 --- a/hugr-passes/src/const_fold2/datalog/partial_value.rs +++ b/hugr-passes/src/const_fold2/datalog/partial_value.rs @@ -14,8 +14,13 @@ pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; } -pub trait FromSum { - fn new_sum(tag: usize, items: impl IntoIterator, st: &SumType) -> Self; +pub trait FromSum: Sized { + type Err: std::error::Error; + fn try_new_sum( + tag: usize, + items: impl IntoIterator, + st: &SumType, + ) -> Result; fn debug_check_is_type(&self, _ty: &Type) {} } @@ -120,7 +125,7 @@ impl PartialSum { .map(|(v, t)| v.clone().try_into_value(t)) .collect::, _>>() { - Ok(vs) => Ok(V2::new_sum(*k, vs, &st)), + Ok(vs) => V2::try_new_sum(*k, vs, &st).map_err(|_| self), Err(_) => Err(self), } } diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index c0594fdb7..a5bc8bfa2 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -3,7 +3,10 @@ // https://github.com/proptest-rs/proptest/issues/447 #![cfg_attr(test, allow(non_local_definitions))] -use std::{cmp::Ordering, ops::{Index, IndexMut}}; +use std::{ + cmp::Ordering, + ops::{Index, IndexMut}, +}; use ascent::lattice::{BoundedLattice, Lattice}; use itertools::zip_eq; @@ -149,8 +152,9 @@ where impl IndexMut for ValueRow where - Vec>: IndexMut { - fn index_mut(&mut self, index: Idx) -> &mut Self::Output { + Vec>: IndexMut, +{ + fn index_mut(&mut self, index: Idx) -> &mut Self::Output { self.0.index_mut(index) } } diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index 2aa1bc2fe..e2bfc1930 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use hugr_core::ops::constant::{CustomConst, Sum}; use hugr_core::ops::Value; -use hugr_core::types::Type; +use hugr_core::types::{ConstTypeError, Type}; use hugr_core::Node; use super::datalog::{AbstractValue, FromSum}; @@ -103,8 +103,13 @@ impl AbstractValue for ValueHandle { } impl FromSum for Value { - fn new_sum(tag: usize, items: impl IntoIterator, st: &hugr_core::types::SumType) -> Self { - Value::Sum(Sum {tag, values: items.into_iter().collect(), sum_type: st.clone()}) + type Err = ConstTypeError; + fn try_new_sum( + tag: usize, + items: impl IntoIterator, + st: &hugr_core::types::SumType, + ) -> Result { + Self::sum(tag, items, st.clone()) } } diff --git a/hugr-passes/src/const_fold2/value_handle/context.rs b/hugr-passes/src/const_fold2/value_handle/context.rs index 14571bb1b..7ac0b6ba5 100644 --- a/hugr-passes/src/const_fold2/value_handle/context.rs +++ b/hugr-passes/src/const_fold2/value_handle/context.rs @@ -68,7 +68,7 @@ impl DFContext for HugrValueContext { &self, n: Node, ins: &[(IncomingPort, Value)], - ) -> Vec<(OutgoingPort,ValueHandle)> { + ) -> Vec<(OutgoingPort, ValueHandle)> { match self.0.get_optype(n) { OpType::LoadConstant(load_op) => { assert!(ins.is_empty()); // static edge, so need to find constant @@ -78,14 +78,21 @@ impl DFContext for HugrValueContext { .unwrap() .0; let const_op = self.0.get_optype(const_node).as_const().unwrap(); - vec![(OutgoingPort::from(0), ValueHandle::new( - const_node.into(), - Arc::new(const_op.value().clone()), - ))] + vec![( + OutgoingPort::from(0), + ValueHandle::new(const_node.into(), Arc::new(const_op.value().clone())), + )] } OpType::CustomOp(CustomOp::Extension(op)) => { - let ins = ins.into_iter().map(|(p,v)|(*p,v.clone())).collect::>(); - op.constant_fold(&ins).map_or(Vec::new(), |outs|outs.into_iter().map(|(p,v)|(p, ValueHandle::new(ValueKey::Node(n), Arc::new(v)))).collect()) + let ins = ins + .into_iter() + .map(|(p, v)| (*p, v.clone())) + .collect::>(); + op.constant_fold(&ins).map_or(Vec::new(), |outs| { + outs.into_iter() + .map(|(p, v)| (p, ValueHandle::new(ValueKey::Node(n), Arc::new(v)))) + .collect() + }) } _ => vec![], } From 1012e0a618608d29b0ad47bbdb187c53945d96e5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 13:50:59 +0100 Subject: [PATCH 035/281] Cargo.toml: use explicit git= tag for ascent --- hugr-passes/Cargo.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index a6ed580c3..4234f7f95 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -15,7 +15,8 @@ categories = ["compilers"] [dependencies] hugr-core = { path = "../hugr-core", version = "0.9.1" } portgraph = { workspace = true } -ascent = "0.6.0" +# This ascent commit has a fix for unsoundness in release/tag 0.6.0: +ascent = {git = "https://github.com/s-arash/ascent", rev="9805d02cb830b6e66abcd4d48836a14cd98366f3"} downcast-rs = { workspace = true } itertools = { workspace = true } lazy_static = { workspace = true } From 6bd8dbaf39f385eb73707b05d046e377dc91cae0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 14:00:33 +0100 Subject: [PATCH 036/281] Fix rebase: TryHash, UnpackTuple/MakeTuple now in prelude --- hugr-passes/src/const_fold2/datalog.rs | 5 +++-- hugr-passes/src/const_fold2/datalog/test.rs | 5 +++-- hugr-passes/src/const_fold2/value_handle.rs | 2 +- hugr-passes/src/const_fold2/value_handle/context.rs | 4 ++-- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index c6aad5840..a3256d41e 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -1,4 +1,5 @@ use ascent::lattice::BoundedLattice; +use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use std::collections::HashMap; use std::hash::Hash; @@ -162,11 +163,11 @@ fn propagate_leaf_op( // Handle basics here. I guess (given the current interface) we could allow // DFContext to handle these but at the least we'd want these impls to be // easily available for reuse. - OpType::MakeTuple(_) => Some(ValueRow::from_iter([PV::variant( + op if op.cast::().is_some() => Some(ValueRow::from_iter([PV::variant( 0, ins.into_iter().cloned(), )])), - OpType::UnpackTuple(_) => { + op if op.cast::().is_some() => { let [tup] = ins.into_iter().collect::>().try_into().unwrap(); tup.variant_values(0, utils::value_outputs(c.hugr(), n).count()) .map(ValueRow::from_iter) diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index 1e3cdcc98..0e5d0c48e 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -2,8 +2,9 @@ use crate::const_fold2::value_handle::HugrValueContext; use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, - extension::{prelude::BOOL_T, ExtensionSet, EMPTY_REG}, - ops::{handle::NodeHandle, OpTrait, UnpackTuple, Value}, + extension::prelude::{UnpackTuple, BOOL_T}, + extension::{ExtensionSet, EMPTY_REG}, + ops::{handle::NodeHandle, OpTrait, Value}, type_row, types::{Signature, SumType, Type, TypeRow}, }; diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index e2bfc1930..e3dab5af0 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -57,7 +57,7 @@ impl ValueKey { pub fn try_new(cst: impl CustomConst) -> Option { let mut hasher = DefaultHasher::new(); - cst.maybe_hash(&mut hasher).then(|| { + cst.try_hash(&mut hasher).then(|| { Self::Const(HashedConst { hash: hasher.finish(), val: Arc::new(cst), diff --git a/hugr-passes/src/const_fold2/value_handle/context.rs b/hugr-passes/src/const_fold2/value_handle/context.rs index 7ac0b6ba5..ccb7d27bd 100644 --- a/hugr-passes/src/const_fold2/value_handle/context.rs +++ b/hugr-passes/src/const_fold2/value_handle/context.rs @@ -2,7 +2,7 @@ use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; -use hugr_core::ops::{CustomOp, OpType, Value}; +use hugr_core::ops::{OpType, Value}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; use super::{ValueHandle, ValueKey}; @@ -83,7 +83,7 @@ impl DFContext for HugrValueContext { ValueHandle::new(const_node.into(), Arc::new(const_op.value().clone())), )] } - OpType::CustomOp(CustomOp::Extension(op)) => { + OpType::ExtensionOp(op) => { let ins = ins .into_iter() .map(|(p, v)| (*p, v.clone())) From f7d288f47d74215d0e9e1d3ba59604c793006d95 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 14:40:45 +0100 Subject: [PATCH 037/281] pub ValueRow+Partial(Value/Sum); add TotalContext --- hugr-passes/src/const_fold2.rs | 1 + hugr-passes/src/const_fold2/datalog.rs | 48 ++++++------------ hugr-passes/src/const_fold2/total_context.rs | 50 +++++++++++++++++++ .../src/const_fold2/value_handle/context.rs | 7 +-- 4 files changed, 67 insertions(+), 39 deletions(-) create mode 100644 hugr-passes/src/const_fold2/total_context.rs diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index 7d6725fb1..db1d99467 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -1,2 +1,3 @@ pub mod datalog; +pub mod total_context; pub mod value_handle; diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index a3256d41e..0dc393620 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -3,25 +3,22 @@ use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use std::collections::HashMap; use std::hash::Hash; -use hugr_core::ops::{OpTrait, OpType}; -use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire}; +use hugr_core::ops::{OpType, Value}; +use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; mod partial_value; mod utils; -use utils::{TailLoopTermination, ValueRow}; +// TODO separate this into its own analysis? +use utils::TailLoopTermination; -pub use partial_value::{AbstractValue, FromSum}; +pub use partial_value::{AbstractValue, FromSum, PartialSum, PartialValue}; +pub use utils::ValueRow; type PV = partial_value::PartialValue; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { - type InterpretableVal: FromSum + From; fn hugr(&self) -> &impl HugrView; - fn interpret_leaf_op( - &self, - node: Node, - ins: &[(IncomingPort, Self::InterpretableVal)], - ) -> Vec<(OutgoingPort, V)>; + fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option>; } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -180,29 +177,7 @@ fn propagate_leaf_op( // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, // thus keeping PartialValue hidden, but AbstractValues // are not necessarily convertible to Value! - op => { - let sig = op.dataflow_signature()?; - let known_ins = sig - .input_types() - .into_iter() - .enumerate() - .zip(ins.iter()) - .filter_map(|((i, ty), pv)| { - pv.clone() - .try_into_value(ty) - .ok() - .map(|v| (IncomingPort::from(i), v)) - }) - .collect::>(); - let known_outs = c.interpret_leaf_op(n, &known_ins); - (!known_outs.is_empty()).then(|| { - let mut res = ValueRow::new(sig.output_count()); - for (p, v) in known_outs { - res[p.index()] = v.into(); - } - res - }) - } + _ => c.interpret_leaf_op(n, ins).map(ValueRow::from_iter), } } @@ -265,8 +240,13 @@ impl> Machine { .find_map(|(_, cond2, case2, i)| (&cond == cond2 && &case == case2).then_some(*i)) .unwrap() } +} - pub fn read_out_wire_value(&self, hugr: impl HugrView, w: Wire) -> Option { +impl> Machine +where + Value: From, +{ + pub fn read_out_wire_value(&self, hugr: impl HugrView, w: Wire) -> Option { // dbg!(&w); let pv = self.read_out_wire_partial_value(w)?; // dbg!(&pv); diff --git a/hugr-passes/src/const_fold2/total_context.rs b/hugr-passes/src/const_fold2/total_context.rs new file mode 100644 index 000000000..69bd516a4 --- /dev/null +++ b/hugr-passes/src/const_fold2/total_context.rs @@ -0,0 +1,50 @@ +use std::hash::Hash; + +use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; + +use super::datalog::{AbstractValue, DFContext, FromSum, PartialValue, ValueRow}; + +/// A simpler interface like [DFContext] but where the context only cares about +/// values that are completely known (in the lattice `V`) +/// rather than e.g. Sums potentially of two variants each of known values. +pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { + type InterpretableVal: FromSum + From; + fn interpret_leaf_op( + &self, + node: Node, + ins: &[(IncomingPort, Self::InterpretableVal)], + ) -> Vec<(OutgoingPort, V)>; +} + +impl> DFContext for T { + fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option> { + let op = self.get_optype(node); + let sig = op.dataflow_signature()?; + let known_ins = sig + .input_types() + .into_iter() + .enumerate() + .zip(ins.iter()) + .filter_map(|((i, ty), pv)| { + pv.clone() + .try_into_value(ty) + .ok() + .map(|v| (IncomingPort::from(i), v)) + }) + .collect::>(); + let known_outs = self.interpret_leaf_op(node, &known_ins); + (!known_outs.is_empty()).then(|| { + let mut res = ValueRow::new(sig.output_count()); + for (p, v) in known_outs { + res[p.index()] = v.into(); + } + res + }) + } + + fn hugr(&self) -> &impl HugrView { + // Adding `fn hugr(&self) -> &impl HugrView` to trait TotalContext + // and calling that here requires a lifetime bound on V, so avoid that + self.as_ref() + } +} diff --git a/hugr-passes/src/const_fold2/value_handle/context.rs b/hugr-passes/src/const_fold2/value_handle/context.rs index ccb7d27bd..b24c57aa8 100644 --- a/hugr-passes/src/const_fold2/value_handle/context.rs +++ b/hugr-passes/src/const_fold2/value_handle/context.rs @@ -6,7 +6,7 @@ use hugr_core::ops::{OpType, Value}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; use super::{ValueHandle, ValueKey}; -use crate::const_fold2::datalog::DFContext; +use crate::const_fold2::total_context::TotalContext; /// An implementation of [DFContext] with [ValueHandle] /// that just stores a Hugr (actually any [HugrView]), @@ -58,11 +58,8 @@ impl Deref for HugrValueContext { } } -impl DFContext for HugrValueContext { +impl TotalContext for HugrValueContext { type InterpretableVal = Value; - fn hugr(&self) -> &impl HugrView { - self.0.as_ref() - } fn interpret_leaf_op( &self, From a62eb0f5506b02065fe7723ea0e9c0796feeb6bb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 14:48:32 +0100 Subject: [PATCH 038/281] Remove DFContext::hugr(), as_ref() does just as well --- hugr-passes/src/const_fold2/datalog.rs | 11 +++++------ hugr-passes/src/const_fold2/total_context.rs | 6 ------ 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 0dc393620..4baef4a9b 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -17,7 +17,6 @@ pub use utils::ValueRow; type PV = partial_value::PartialValue; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { - fn hugr(&self) -> &impl HugrView; fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option>; } @@ -43,9 +42,9 @@ ascent::ascent! { node(c, n) <-- context(c), for n in c.nodes(); - in_wire(c, n,p) <-- node(c, n), for p in utils::value_inputs(c.hugr(), *n); + in_wire(c, n,p) <-- node(c, n), for p in utils::value_inputs(c.as_ref(), *n); - out_wire(c, n,p) <-- node(c, n), for p in utils::value_outputs(c.hugr(), *n); + out_wire(c, n,p) <-- node(c, n), for p in utils::value_outputs(c.as_ref(), *n); parent_of_node(c, parent, child) <-- node(c, child), if let Some(parent) = c.get_parent(*child); @@ -64,8 +63,8 @@ ascent::ascent! { out_wire_value(c, m, op, v); - node_in_value_row(c, n, utils::bottom_row(c.hugr(), *n)) <-- node(c, n); - node_in_value_row(c, n, utils::singleton_in_row(c.hugr(), n, p, v.clone())) <-- in_wire_value(c, n, p, v); + node_in_value_row(c, n, utils::bottom_row(c.as_ref(), *n)) <-- node(c, n); + node_in_value_row(c, n, utils::singleton_in_row(c.as_ref(), n, p, v.clone())) <-- in_wire_value(c, n, p, v); out_wire_value(c, n, p, v) <-- node(c, n), @@ -166,7 +165,7 @@ fn propagate_leaf_op( )])), op if op.cast::().is_some() => { let [tup] = ins.into_iter().collect::>().try_into().unwrap(); - tup.variant_values(0, utils::value_outputs(c.hugr(), n).count()) + tup.variant_values(0, utils::value_outputs(c.as_ref(), n).count()) .map(ValueRow::from_iter) } OpType::Tag(t) => Some(ValueRow::from_iter([PV::variant( diff --git a/hugr-passes/src/const_fold2/total_context.rs b/hugr-passes/src/const_fold2/total_context.rs index 69bd516a4..2ccb7db88 100644 --- a/hugr-passes/src/const_fold2/total_context.rs +++ b/hugr-passes/src/const_fold2/total_context.rs @@ -41,10 +41,4 @@ impl> DFContext for T { res }) } - - fn hugr(&self) -> &impl HugrView { - // Adding `fn hugr(&self) -> &impl HugrView` to trait TotalContext - // and calling that here requires a lifetime bound on V, so avoid that - self.as_ref() - } } From f0ec2373fa253107cc3b6f5524d1c205cd541997 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 15:13:01 +0100 Subject: [PATCH 039/281] Move partial_value out of datalog, combine tests; move ValueRow out of utils --- hugr-passes/src/const_fold2.rs | 9 +- hugr-passes/src/const_fold2/datalog.rs | 13 +- .../const_fold2/datalog/partial_value/test.rs | 347 ----------------- hugr-passes/src/const_fold2/datalog/test.rs | 2 +- hugr-passes/src/const_fold2/datalog/utils.rs | 156 +------- .../{datalog => }/partial_value.rs | 357 +++++++++++++++++- hugr-passes/src/const_fold2/total_context.rs | 4 +- hugr-passes/src/const_fold2/value_handle.rs | 2 +- hugr-passes/src/const_fold2/value_row.rs | 117 ++++++ 9 files changed, 500 insertions(+), 507 deletions(-) delete mode 100644 hugr-passes/src/const_fold2/datalog/partial_value/test.rs rename hugr-passes/src/const_fold2/{datalog => }/partial_value.rs (54%) create mode 100644 hugr-passes/src/const_fold2/value_row.rs diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index db1d99467..b0ab62fdc 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -1,3 +1,10 @@ -pub mod datalog; +mod datalog; +pub use datalog::Machine; + +pub mod partial_value; + +mod value_row; +pub use value_row::ValueRow; + pub mod total_context; pub mod value_handle; diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 4baef4a9b..064aab91f 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -6,18 +6,17 @@ use std::hash::Hash; use hugr_core::ops::{OpType, Value}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; -mod partial_value; mod utils; // TODO separate this into its own analysis? use utils::TailLoopTermination; -pub use partial_value::{AbstractValue, FromSum, PartialSum, PartialValue}; -pub use utils::ValueRow; -type PV = partial_value::PartialValue; +use super::partial_value::AbstractValue; +use super::value_row::ValueRow; +type PV = super::partial_value::PartialValue; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { - fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option>; + fn interpret_leaf_op(&self, node: Node, ins: &[PV]) -> Option>; } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -63,8 +62,8 @@ ascent::ascent! { out_wire_value(c, m, op, v); - node_in_value_row(c, n, utils::bottom_row(c.as_ref(), *n)) <-- node(c, n); - node_in_value_row(c, n, utils::singleton_in_row(c.as_ref(), n, p, v.clone())) <-- in_wire_value(c, n, p, v); + node_in_value_row(c, n, ValueRow::new(utils::input_count(c.as_ref(), *n))) <-- node(c, n); + node_in_value_row(c, n, ValueRow::single_known(c.signature(*n).unwrap().input.len(), p.index(), v.clone())) <-- in_wire_value(c, n, p, v); out_wire_value(c, n, p, v) <-- node(c, n), diff --git a/hugr-passes/src/const_fold2/datalog/partial_value/test.rs b/hugr-passes/src/const_fold2/datalog/partial_value/test.rs deleted file mode 100644 index 33c8f3c8d..000000000 --- a/hugr-passes/src/const_fold2/datalog/partial_value/test.rs +++ /dev/null @@ -1,347 +0,0 @@ -use std::sync::Arc; - -use itertools::{zip_eq, Either, Itertools as _}; -use proptest::prelude::*; - -use hugr_core::{ - std_extensions::arithmetic::int_types::{self, ConstInt, INT_TYPES, LOG_WIDTH_BOUND}, - types::{Type, TypeArg, TypeEnum}, -}; - -use super::{PartialSum, PartialValue}; -use crate::const_fold2::value_handle::{ValueHandle, ValueKey}; - -impl Arbitrary for ValueHandle { - type Parameters = (); - type Strategy = BoxedStrategy; - fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { - // prop_oneof![ - - // ] - todo!() - } -} - -#[derive(Debug, PartialEq, Eq, Clone)] -enum TestSumLeafType { - Int(Type), - Unit, -} - -impl TestSumLeafType { - fn assert_invariants(&self) { - match self { - Self::Int(t) => { - if let TypeEnum::Extension(ct) = t.as_type_enum() { - assert_eq!("int", ct.name()); - assert_eq!(&int_types::EXTENSION_ID, ct.extension()); - } else { - panic!("Expected int type, got {:#?}", t); - } - } - _ => (), - } - } - - fn get_type(&self) -> Type { - match self { - Self::Int(t) => t.clone(), - Self::Unit => Type::UNIT, - } - } - - fn type_check(&self, ps: &PartialSum) -> bool { - match self { - Self::Int(_) => false, - Self::Unit => { - if let Ok((0, v)) = ps.0.iter().exactly_one() { - v.is_empty() - } else { - false - } - } - } - } - - fn partial_value_strategy(self) -> impl Strategy> { - match self { - Self::Int(t) => { - let TypeEnum::Extension(ct) = t.as_type_enum() else { - unreachable!() - }; - // TODO this should be get_log_width, but that's not pub - let TypeArg::BoundedNat { n: lw } = ct.args()[0] else { - panic!() - }; - (0u64..(1 << (2u64.pow(lw as u32) - 1))) - .prop_map(move |x| { - let ki = ConstInt::new_u(lw as u8, x).unwrap(); - let k = ValueKey::try_new(ki.clone()).unwrap(); - ValueHandle::new(k, Arc::new(ki.into())).into() - }) - .boxed() - } - Self::Unit => Just(PartialSum::unit().into()).boxed(), - } - } -} - -impl Arbitrary for TestSumLeafType { - type Parameters = (); - type Strategy = BoxedStrategy; - fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { - let int_strat = (0..LOG_WIDTH_BOUND).prop_map(|i| Self::Int(INT_TYPES[i as usize].clone())); - prop_oneof![Just(TestSumLeafType::Unit), int_strat].boxed() - } -} - -#[derive(Debug, PartialEq, Eq, Clone)] -enum TestSumType { - Branch(usize, Vec>>), - Leaf(TestSumLeafType), -} - -impl TestSumType { - const UNIT: TestSumLeafType = TestSumLeafType::Unit; - - fn leaf(v: Type) -> Self { - TestSumType::Leaf(TestSumLeafType::Int(v)) - } - - fn branch(vs: impl IntoIterator>>) -> Self { - let vec = vs.into_iter().collect_vec(); - let depth: usize = vec - .iter() - .flat_map(|x| x.iter()) - .map(|x| x.depth() + 1) - .max() - .unwrap_or(0); - Self::Branch(depth, vec) - } - - fn depth(&self) -> usize { - match self { - TestSumType::Branch(x, _) => *x, - TestSumType::Leaf(_) => 0, - } - } - - fn is_leaf(&self) -> bool { - self.depth() == 0 - } - - fn assert_invariants(&self) { - match self { - TestSumType::Branch(d, sop) => { - assert!(!sop.is_empty(), "No variants"); - for v in sop.iter().flat_map(|x| x.iter()) { - assert!(v.depth() < *d); - v.assert_invariants(); - } - } - TestSumType::Leaf(l) => { - l.assert_invariants(); - } - } - } - - fn select(self) -> impl Strategy>)>> { - match self { - TestSumType::Branch(_, sop) => any::() - .prop_map(move |i| { - let index = i.index(sop.len()); - Either::Right((index, sop[index].clone())) - }) - .boxed(), - TestSumType::Leaf(l) => Just(Either::Left(l)).boxed(), - } - } - - fn get_type(&self) -> Type { - match self { - TestSumType::Branch(_, sop) => Type::new_sum( - sop.iter() - .map(|row| row.iter().map(|x| x.get_type()).collect_vec()), - ), - TestSumType::Leaf(l) => l.get_type(), - } - } - - fn type_check(&self, pv: &PartialValue) -> bool { - match (self, pv) { - (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, - (_, PartialValue::Value(v)) => self.get_type() == v.get_type(), - (TestSumType::Branch(_, sop), PartialValue::PartialSum(ps)) => { - for (k, v) in &ps.0 { - if *k >= sop.len() { - return false; - } - let prod = &sop[*k]; - if prod.len() != v.len() { - return false; - } - if !zip_eq(prod, v).all(|(lhs, rhs)| lhs.type_check(rhs)) { - return false; - } - } - true - } - (Self::Leaf(l), PartialValue::PartialSum(ps)) => l.type_check(ps), - } - } -} - -impl From for TestSumType { - fn from(value: TestSumLeafType) -> Self { - Self::Leaf(value) - } -} - -#[derive(Clone, PartialEq, Eq, Debug)] -struct UnarySumTypeParams { - depth: usize, - branch_width: usize, -} - -impl UnarySumTypeParams { - pub fn descend(mut self, d: usize) -> Self { - assert!(d < self.depth); - self.depth = d; - self - } -} - -impl Default for UnarySumTypeParams { - fn default() -> Self { - Self { - depth: 3, - branch_width: 3, - } - } -} - -impl Arbitrary for TestSumType { - type Parameters = UnarySumTypeParams; - type Strategy = BoxedStrategy; - fn arbitrary_with( - params @ UnarySumTypeParams { - depth, - branch_width, - }: Self::Parameters, - ) -> Self::Strategy { - if depth == 0 { - any::().prop_map_into().boxed() - } else { - (0..depth) - .prop_flat_map(move |d| { - prop::collection::vec( - prop::collection::vec( - any_with::(params.clone().descend(d)).prop_map_into(), - 0..branch_width, - ), - 1..=branch_width, - ) - .prop_map(TestSumType::branch) - }) - .boxed() - } - } -} - -proptest! { - #[test] - fn unary_sum_type_valid(ust: TestSumType) { - ust.assert_invariants(); - } -} - -fn any_partial_value_of_type(ust: TestSumType) -> impl Strategy> { - ust.select().prop_flat_map(|x| match x { - Either::Left(l) => l.partial_value_strategy().boxed(), - Either::Right((index, usts)) => { - let pvs = usts - .into_iter() - .map(|x| { - any_partial_value_of_type( - Arc::::try_unwrap(x).unwrap_or_else(|x| x.as_ref().clone()), - ) - }) - .collect_vec(); - pvs.prop_map(move |pvs| PartialValue::variant(index, pvs)) - .boxed() - } - }) -} - -fn any_partial_value_with( - params: ::Parameters, -) -> impl Strategy> { - any_with::(params).prop_flat_map(any_partial_value_of_type) -} - -fn any_partial_value() -> impl Strategy> { - any_partial_value_with(Default::default()) -} - -fn any_partial_values() -> impl Strategy; N]> { - any::().prop_flat_map(|ust| { - TryInto::<[_; N]>::try_into( - (0..N) - .map(|_| any_partial_value_of_type(ust.clone())) - .collect_vec(), - ) - .unwrap() - }) -} - -fn any_typed_partial_value() -> impl Strategy)> { - any::() - .prop_flat_map(|t| any_partial_value_of_type(t.clone()).prop_map(move |v| (t.clone(), v))) -} - -proptest! { - #[test] - fn partial_value_type((tst, pv) in any_typed_partial_value()) { - prop_assert!(tst.type_check(&pv)) - } - - // todo: ValidHandle is valid - // todo: ValidHandle eq is an equivalence relation - - // todo: PartialValue PartialOrd is transitive - // todo: PartialValue eq is an equivalence relation - #[test] - fn partial_value_valid(pv in any_partial_value()) { - pv.assert_invariants(); - } - - #[test] - fn bounded_lattice(v in any_partial_value()) { - prop_assert!(v <= PartialValue::top()); - prop_assert!(v >= PartialValue::bottom()); - } - - #[test] - fn meet_join_self_noop(v1 in any_partial_value()) { - let mut subject = v1.clone(); - - assert_eq!(v1.clone(), v1.clone().join(v1.clone())); - assert!(!subject.join_mut(v1.clone())); - assert_eq!(subject, v1); - - assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); - assert!(!subject.meet_mut(v1.clone())); - assert_eq!(subject, v1); - } - - #[test] - fn lattice([v1,v2] in any_partial_values()) { - let meet = v1.clone().meet(v2.clone()); - prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); - prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); - - let join = v1.clone().join(v2.clone()); - prop_assert!(join >= v1, "join not >=: {:#?}", &join); - prop_assert!(join >= v2, "join not >=: {:#?}", &join); - } -} diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index 0e5d0c48e..42c34c451 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -9,7 +9,7 @@ use hugr_core::{ types::{Signature, SumType, Type, TypeRow}, }; -use super::partial_value::PartialValue; +use super::super::partial_value::PartialValue; use super::*; diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index a5bc8bfa2..ae486e280 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -3,20 +3,12 @@ // https://github.com/proptest-rs/proptest/issues/447 #![cfg_attr(test, allow(non_local_definitions))] -use std::{ - cmp::Ordering, - ops::{Index, IndexMut}, -}; +use std::cmp::Ordering; use ascent::lattice::{BoundedLattice, Lattice}; -use itertools::zip_eq; -use super::{partial_value::PartialValue, AbstractValue}; -use hugr_core::{ - ops::OpTrait as _, - types::{Signature, TypeRow}, - HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, -}; +use super::super::partial_value::{AbstractValue, PartialValue}; +use hugr_core::{types::Signature, HugrView, IncomingPort, Node, OutgoingPort}; #[cfg(test)] use proptest_derive::Arbitrary; @@ -49,143 +41,11 @@ impl BoundedLattice for PartialValue { } } -#[derive(PartialEq, Clone, Eq, Hash)] -pub struct ValueRow(Vec>); - -impl ValueRow { - pub fn new(len: usize) -> Self { - Self(vec![PartialValue::bottom(); len]) - } - - fn single_among_bottoms(len: usize, idx: usize, v: PartialValue) -> Self { - assert!(idx < len); - let mut r = Self::new(len); - r.0[idx] = v; - r - } - - fn bottom_from_row(r: &TypeRow) -> Self { - Self::new(r.len()) - } - - pub fn iter(&self) -> impl Iterator> { - self.0.iter() - } - - pub fn unpack_first( - &self, - variant: usize, - len: usize, - ) -> Option> + '_> { - self[0] - .variant_values(variant, len) - .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) - } - - // fn initialised(&self) -> bool { - // self.0.iter().all(|x| x != &PV::top()) - // } -} - -impl FromIterator> for ValueRow { - fn from_iter>>(iter: T) -> Self { - Self(iter.into_iter().collect()) - } -} - -impl PartialOrd for ValueRow { - fn partial_cmp(&self, other: &Self) -> Option { - self.0.partial_cmp(&other.0) - } -} - -impl Lattice for ValueRow { - fn meet(mut self, other: Self) -> Self { - self.meet_mut(other); - self - } - - fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self - } - - fn join_mut(&mut self, other: Self) -> bool { - assert_eq!(self.0.len(), other.0.len()); - let mut changed = false; - for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { - changed |= v1.join_mut(v2); - } - changed - } - - fn meet_mut(&mut self, other: Self) -> bool { - assert_eq!(self.0.len(), other.0.len()); - let mut changed = false; - for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { - changed |= v1.meet_mut(v2); - } - changed - } -} - -impl IntoIterator for ValueRow { - type Item = PartialValue; - - type IntoIter = > as IntoIterator>::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -impl Index for ValueRow -where - Vec>: Index, -{ - type Output = > as Index>::Output; - - fn index(&self, index: Idx) -> &Self::Output { - self.0.index(index) - } -} - -impl IndexMut for ValueRow -where - Vec>: IndexMut, -{ - fn index_mut(&mut self, index: Idx) -> &mut Self::Output { - self.0.index_mut(index) - } -} - -pub(super) fn bottom_row(h: &impl HugrView, n: Node) -> ValueRow { - ValueRow::new( - h.signature(n) - .as_ref() - .map(Signature::input_count) - .unwrap_or(0), - ) -} - -pub(super) fn singleton_in_row( - h: &impl HugrView, - n: &Node, - ip: &IncomingPort, - v: PartialValue, -) -> ValueRow { - let Some(sig) = h.signature(*n) else { - panic!("dougrulz"); - }; - if sig.input_count() <= ip.index() { - panic!( - "bad port index: {} >= {}: {}", - ip.index(), - sig.input_count(), - h.get_optype(*n).description() - ); - } - ValueRow::single_among_bottoms(h.signature(*n).unwrap().input.len(), ip.index(), v) +pub(super) fn input_count(h: &impl HugrView, n: Node) -> usize { + h.signature(n) + .as_ref() + .map(Signature::input_count) + .unwrap_or(0) } pub(super) fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { diff --git a/hugr-passes/src/const_fold2/datalog/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs similarity index 54% rename from hugr-passes/src/const_fold2/datalog/partial_value.rs rename to hugr-passes/src/const_fold2/partial_value.rs index 73e962287..933376027 100644 --- a/hugr-passes/src/const_fold2/datalog/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -464,4 +464,359 @@ impl PartialOrd for PartialValue { } #[cfg(test)] -mod test; +mod test { + use std::sync::Arc; + + use itertools::{zip_eq, Either, Itertools as _}; + use proptest::prelude::*; + + use hugr_core::{ + std_extensions::arithmetic::int_types::{self, ConstInt, INT_TYPES, LOG_WIDTH_BOUND}, + types::{Type, TypeArg, TypeEnum}, + }; + + use super::{PartialSum, PartialValue}; + use crate::const_fold2::value_handle::{ValueHandle, ValueKey}; + + impl Arbitrary for ValueHandle { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { + // prop_oneof![ + + // ] + todo!() + } + } + + #[derive(Debug, PartialEq, Eq, Clone)] + enum TestSumLeafType { + Int(Type), + Unit, + } + + impl TestSumLeafType { + fn assert_invariants(&self) { + match self { + Self::Int(t) => { + if let TypeEnum::Extension(ct) = t.as_type_enum() { + assert_eq!("int", ct.name()); + assert_eq!(&int_types::EXTENSION_ID, ct.extension()); + } else { + panic!("Expected int type, got {:#?}", t); + } + } + _ => (), + } + } + + fn get_type(&self) -> Type { + match self { + Self::Int(t) => t.clone(), + Self::Unit => Type::UNIT, + } + } + + fn type_check(&self, ps: &PartialSum) -> bool { + match self { + Self::Int(_) => false, + Self::Unit => { + if let Ok((0, v)) = ps.0.iter().exactly_one() { + v.is_empty() + } else { + false + } + } + } + } + + fn partial_value_strategy(self) -> impl Strategy> { + match self { + Self::Int(t) => { + let TypeEnum::Extension(ct) = t.as_type_enum() else { + unreachable!() + }; + // TODO this should be get_log_width, but that's not pub + let TypeArg::BoundedNat { n: lw } = ct.args()[0] else { + panic!() + }; + (0u64..(1 << (2u64.pow(lw as u32) - 1))) + .prop_map(move |x| { + let ki = ConstInt::new_u(lw as u8, x).unwrap(); + let k = ValueKey::try_new(ki.clone()).unwrap(); + ValueHandle::new(k, Arc::new(ki.into())).into() + }) + .boxed() + } + Self::Unit => Just(PartialSum::unit().into()).boxed(), + } + } + } + + impl Arbitrary for TestSumLeafType { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { + let int_strat = + (0..LOG_WIDTH_BOUND).prop_map(|i| Self::Int(INT_TYPES[i as usize].clone())); + prop_oneof![Just(TestSumLeafType::Unit), int_strat].boxed() + } + } + + #[derive(Debug, PartialEq, Eq, Clone)] + enum TestSumType { + Branch(usize, Vec>>), + Leaf(TestSumLeafType), + } + + impl TestSumType { + const UNIT: TestSumLeafType = TestSumLeafType::Unit; + + fn leaf(v: Type) -> Self { + TestSumType::Leaf(TestSumLeafType::Int(v)) + } + + fn branch(vs: impl IntoIterator>>) -> Self { + let vec = vs.into_iter().collect_vec(); + let depth: usize = vec + .iter() + .flat_map(|x| x.iter()) + .map(|x| x.depth() + 1) + .max() + .unwrap_or(0); + Self::Branch(depth, vec) + } + + fn depth(&self) -> usize { + match self { + TestSumType::Branch(x, _) => *x, + TestSumType::Leaf(_) => 0, + } + } + + fn is_leaf(&self) -> bool { + self.depth() == 0 + } + + fn assert_invariants(&self) { + match self { + TestSumType::Branch(d, sop) => { + assert!(!sop.is_empty(), "No variants"); + for v in sop.iter().flat_map(|x| x.iter()) { + assert!(v.depth() < *d); + v.assert_invariants(); + } + } + TestSumType::Leaf(l) => { + l.assert_invariants(); + } + } + } + + fn select(self) -> impl Strategy>)>> { + match self { + TestSumType::Branch(_, sop) => any::() + .prop_map(move |i| { + let index = i.index(sop.len()); + Either::Right((index, sop[index].clone())) + }) + .boxed(), + TestSumType::Leaf(l) => Just(Either::Left(l)).boxed(), + } + } + + fn get_type(&self) -> Type { + match self { + TestSumType::Branch(_, sop) => Type::new_sum( + sop.iter() + .map(|row| row.iter().map(|x| x.get_type()).collect_vec()), + ), + TestSumType::Leaf(l) => l.get_type(), + } + } + + fn type_check(&self, pv: &PartialValue) -> bool { + match (self, pv) { + (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, + (_, PartialValue::Value(v)) => self.get_type() == v.get_type(), + (TestSumType::Branch(_, sop), PartialValue::PartialSum(ps)) => { + for (k, v) in &ps.0 { + if *k >= sop.len() { + return false; + } + let prod = &sop[*k]; + if prod.len() != v.len() { + return false; + } + if !zip_eq(prod, v).all(|(lhs, rhs)| lhs.type_check(rhs)) { + return false; + } + } + true + } + (Self::Leaf(l), PartialValue::PartialSum(ps)) => l.type_check(ps), + } + } + } + + impl From for TestSumType { + fn from(value: TestSumLeafType) -> Self { + Self::Leaf(value) + } + } + + #[derive(Clone, PartialEq, Eq, Debug)] + struct UnarySumTypeParams { + depth: usize, + branch_width: usize, + } + + impl UnarySumTypeParams { + pub fn descend(mut self, d: usize) -> Self { + assert!(d < self.depth); + self.depth = d; + self + } + } + + impl Default for UnarySumTypeParams { + fn default() -> Self { + Self { + depth: 3, + branch_width: 3, + } + } + } + + impl Arbitrary for TestSumType { + type Parameters = UnarySumTypeParams; + type Strategy = BoxedStrategy; + fn arbitrary_with( + params @ UnarySumTypeParams { + depth, + branch_width, + }: Self::Parameters, + ) -> Self::Strategy { + if depth == 0 { + any::().prop_map_into().boxed() + } else { + (0..depth) + .prop_flat_map(move |d| { + prop::collection::vec( + prop::collection::vec( + any_with::(params.clone().descend(d)).prop_map_into(), + 0..branch_width, + ), + 1..=branch_width, + ) + .prop_map(TestSumType::branch) + }) + .boxed() + } + } + } + + proptest! { + #[test] + fn unary_sum_type_valid(ust: TestSumType) { + ust.assert_invariants(); + } + } + + fn any_partial_value_of_type( + ust: TestSumType, + ) -> impl Strategy> { + ust.select().prop_flat_map(|x| match x { + Either::Left(l) => l.partial_value_strategy().boxed(), + Either::Right((index, usts)) => { + let pvs = usts + .into_iter() + .map(|x| { + any_partial_value_of_type( + Arc::::try_unwrap(x) + .unwrap_or_else(|x| x.as_ref().clone()), + ) + }) + .collect_vec(); + pvs.prop_map(move |pvs| PartialValue::variant(index, pvs)) + .boxed() + } + }) + } + + fn any_partial_value_with( + params: ::Parameters, + ) -> impl Strategy> { + any_with::(params).prop_flat_map(any_partial_value_of_type) + } + + fn any_partial_value() -> impl Strategy> { + any_partial_value_with(Default::default()) + } + + fn any_partial_values() -> impl Strategy; N]> + { + any::().prop_flat_map(|ust| { + TryInto::<[_; N]>::try_into( + (0..N) + .map(|_| any_partial_value_of_type(ust.clone())) + .collect_vec(), + ) + .unwrap() + }) + } + + fn any_typed_partial_value() -> impl Strategy)> + { + any::().prop_flat_map(|t| { + any_partial_value_of_type(t.clone()).prop_map(move |v| (t.clone(), v)) + }) + } + + proptest! { + #[test] + fn partial_value_type((tst, pv) in any_typed_partial_value()) { + prop_assert!(tst.type_check(&pv)) + } + + // todo: ValidHandle is valid + // todo: ValidHandle eq is an equivalence relation + + // todo: PartialValue PartialOrd is transitive + // todo: PartialValue eq is an equivalence relation + #[test] + fn partial_value_valid(pv in any_partial_value()) { + pv.assert_invariants(); + } + + #[test] + fn bounded_lattice(v in any_partial_value()) { + prop_assert!(v <= PartialValue::top()); + prop_assert!(v >= PartialValue::bottom()); + } + + #[test] + fn meet_join_self_noop(v1 in any_partial_value()) { + let mut subject = v1.clone(); + + assert_eq!(v1.clone(), v1.clone().join(v1.clone())); + assert!(!subject.join_mut(v1.clone())); + assert_eq!(subject, v1); + + assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); + assert!(!subject.meet_mut(v1.clone())); + assert_eq!(subject, v1); + } + + #[test] + fn lattice([v1,v2] in any_partial_values()) { + let meet = v1.clone().meet(v2.clone()); + prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); + prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); + + let join = v1.clone().join(v2.clone()); + prop_assert!(join >= v1, "join not >=: {:#?}", &join); + prop_assert!(join >= v2, "join not >=: {:#?}", &join); + } + } +} diff --git a/hugr-passes/src/const_fold2/total_context.rs b/hugr-passes/src/const_fold2/total_context.rs index 2ccb7db88..63a7b4965 100644 --- a/hugr-passes/src/const_fold2/total_context.rs +++ b/hugr-passes/src/const_fold2/total_context.rs @@ -2,7 +2,9 @@ use std::hash::Hash; use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; -use super::datalog::{AbstractValue, DFContext, FromSum, PartialValue, ValueRow}; +use super::datalog::DFContext; +use super::partial_value::{AbstractValue, FromSum, PartialValue}; +use super::ValueRow; /// A simpler interface like [DFContext] but where the context only cares about /// values that are completely known (in the lattice `V`) diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index e3dab5af0..3d26d5ac8 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -6,7 +6,7 @@ use hugr_core::ops::Value; use hugr_core::types::{ConstTypeError, Type}; use hugr_core::Node; -use super::datalog::{AbstractValue, FromSum}; +use super::partial_value::{AbstractValue, FromSum}; mod context; pub use context::HugrValueContext; diff --git a/hugr-passes/src/const_fold2/value_row.rs b/hugr-passes/src/const_fold2/value_row.rs new file mode 100644 index 000000000..91b45052c --- /dev/null +++ b/hugr-passes/src/const_fold2/value_row.rs @@ -0,0 +1,117 @@ +// Really this is part of partial_value.rs + +use std::{ + cmp::Ordering, + ops::{Index, IndexMut}, +}; + +use ascent::lattice::Lattice; +use itertools::zip_eq; + +use super::partial_value::{AbstractValue, PartialValue}; + +#[derive(PartialEq, Clone, Eq, Hash)] +pub struct ValueRow(Vec>); + +impl ValueRow { + pub fn new(len: usize) -> Self { + Self(vec![PartialValue::bottom(); len]) + } + + pub fn single_known(len: usize, idx: usize, v: PartialValue) -> Self { + assert!(idx < len); + let mut r = Self::new(len); + r.0[idx] = v; + r + } + + pub fn iter(&self) -> impl Iterator> { + self.0.iter() + } + + pub fn unpack_first( + &self, + variant: usize, + len: usize, + ) -> Option> + '_> { + self[0] + .variant_values(variant, len) + .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) + } + + // fn initialised(&self) -> bool { + // self.0.iter().all(|x| x != &PV::top()) + // } +} + +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl PartialOrd for ValueRow { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} + +impl Lattice for ValueRow { + fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn join_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.join_mut(v2); + } + changed + } + + fn meet_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.meet_mut(v2); + } + changed + } +} + +impl IntoIterator for ValueRow { + type Item = PartialValue; + + type IntoIter = > as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl Index for ValueRow +where + Vec>: Index, +{ + type Output = > as Index>::Output; + + fn index(&self, index: Idx) -> &Self::Output { + self.0.index(index) + } +} + +impl IndexMut for ValueRow +where + Vec>: IndexMut, +{ + fn index_mut(&mut self, index: Idx) -> &mut Self::Output { + self.0.index_mut(index) + } +} From 82a3f22d256e41ac2a4e86c3dd4948695242bb3c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 15:21:44 +0100 Subject: [PATCH 040/281] Move FromSum and try_into_value into total_context.rs --- hugr-passes/src/const_fold2/partial_value.rs | 49 +------------- hugr-passes/src/const_fold2/total_context.rs | 68 +++++++++++++++++++- hugr-passes/src/const_fold2/value_handle.rs | 15 +---- 3 files changed, 70 insertions(+), 62 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index 933376027..3c20ff965 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -1,7 +1,6 @@ #![allow(missing_docs)] -use hugr_core::types::{SumType, Type, TypeEnum, TypeRow}; -use itertools::{zip_eq, Itertools}; +use itertools::zip_eq; use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; @@ -14,16 +13,6 @@ pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; } -pub trait FromSum: Sized { - type Err: std::error::Error; - fn try_new_sum( - tag: usize, - items: impl IntoIterator, - st: &SumType, - ) -> Result; - fn debug_check_is_type(&self, _ty: &Type) {} -} - // TODO ALAN inline into PartialValue? Has to be public as it's in a pub enum #[derive(PartialEq, Clone, Eq)] pub struct PartialSum(pub HashMap>>); @@ -103,32 +92,6 @@ impl PartialSum { pub fn supports_tag(&self, tag: usize) -> bool { self.0.contains_key(&tag) } - - pub fn try_into_value>(self, typ: &Type) -> Result { - let Ok((k, v)) = self.0.iter().exactly_one() else { - Err(self)? - }; - - let TypeEnum::Sum(st) = typ.as_type_enum() else { - Err(self)? - }; - let Some(r) = st.get_variant(*k) else { - Err(self)? - }; - let Ok(r): Result = r.clone().try_into() else { - Err(self)? - }; - if v.len() != r.len() { - return Err(self); - } - match zip_eq(v.into_iter(), r.into_iter()) - .map(|(v, t)| v.clone().try_into_value(t)) - .collect::, _>>() - { - Ok(vs) => V2::try_new_sum(*k, vs, &st).map_err(|_| self), - Err(_) => Err(self), - } - } } impl PartialSum { @@ -434,16 +397,6 @@ impl PartialValue { PartialValue::Top => true, } } - - pub fn try_into_value>(self, typ: &Type) -> Result { - let r: V2 = match self { - Self::Value(v) => Ok(v.clone().into()), - Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), - x => Err(x), - }?; - r.debug_check_is_type(typ); - Ok(r) - } } impl PartialOrd for PartialValue { diff --git a/hugr-passes/src/const_fold2/total_context.rs b/hugr-passes/src/const_fold2/total_context.rs index 63a7b4965..8caafda52 100644 --- a/hugr-passes/src/const_fold2/total_context.rs +++ b/hugr-passes/src/const_fold2/total_context.rs @@ -1,11 +1,24 @@ use std::hash::Hash; +use hugr_core::ops::Value; +use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; +use itertools::{zip_eq, Itertools}; use super::datalog::DFContext; -use super::partial_value::{AbstractValue, FromSum, PartialValue}; +use super::partial_value::{AbstractValue, PartialSum, PartialValue}; use super::ValueRow; +pub trait FromSum: Sized { + type Err: std::error::Error; + fn try_new_sum( + tag: usize, + items: impl IntoIterator, + st: &SumType, + ) -> Result; + fn debug_check_is_type(&self, _ty: &Type) {} +} + /// A simpler interface like [DFContext] but where the context only cares about /// values that are completely known (in the lattice `V`) /// rather than e.g. Sums potentially of two variants each of known values. @@ -18,6 +31,59 @@ pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { ) -> Vec<(OutgoingPort, V)>; } +impl FromSum for Value { + type Err = ConstTypeError; + fn try_new_sum( + tag: usize, + items: impl IntoIterator, + st: &hugr_core::types::SumType, + ) -> Result { + Self::sum(tag, items, st.clone()) + } +} + +// These are here because they rely on FromSum, that they are `impl PartialSum/Value` +// is merely a nice syntax. +impl PartialValue { + pub fn try_into_value>(self, typ: &Type) -> Result { + let r: V2 = match self { + Self::Value(v) => Ok(v.clone().into()), + Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), + x => Err(x), + }?; + r.debug_check_is_type(typ); + Ok(r) + } +} + +impl PartialSum { + pub fn try_into_value>(self, typ: &Type) -> Result { + let Ok((k, v)) = self.0.iter().exactly_one() else { + Err(self)? + }; + + let TypeEnum::Sum(st) = typ.as_type_enum() else { + Err(self)? + }; + let Some(r) = st.get_variant(*k) else { + Err(self)? + }; + let Ok(r): Result = r.clone().try_into() else { + Err(self)? + }; + if v.len() != r.len() { + return Err(self); + } + match zip_eq(v.into_iter(), r.into_iter()) + .map(|(v, t)| v.clone().try_into_value(t)) + .collect::, _>>() + { + Ok(vs) => V2::try_new_sum(*k, vs, &st).map_err(|_| self), + Err(_) => Err(self), + } + } +} + impl> DFContext for T { fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option> { let op = self.get_optype(node); diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index 3d26d5ac8..137699763 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -3,10 +3,10 @@ use std::sync::Arc; use hugr_core::ops::constant::{CustomConst, Sum}; use hugr_core::ops::Value; -use hugr_core::types::{ConstTypeError, Type}; +use hugr_core::types::Type; use hugr_core::Node; -use super::partial_value::{AbstractValue, FromSum}; +use super::partial_value::AbstractValue; mod context; pub use context::HugrValueContext; @@ -102,17 +102,6 @@ impl AbstractValue for ValueHandle { } } -impl FromSum for Value { - type Err = ConstTypeError; - fn try_new_sum( - tag: usize, - items: impl IntoIterator, - st: &hugr_core::types::SumType, - ) -> Result { - Self::sum(tag, items, st.clone()) - } -} - impl PartialEq for ValueHandle { fn eq(&self, other: &Self) -> bool { // If the keys are equal, we return true since the values must have the From e6dc114b87fe97cbe45cbe88b17399d4c594a17b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 15:41:48 +0100 Subject: [PATCH 041/281] Separate mod dataflow from mod const_fold2 --- hugr-passes/src/const_fold2.rs | 13 +++++-------- .../src/const_fold2/{value_handle => }/context.rs | 4 ++-- hugr-passes/src/const_fold2/value_handle.rs | 5 +---- hugr-passes/src/dataflow.rs | 13 +++++++++++++ .../src/{const_fold2 => dataflow}/datalog.rs | 0 .../src/{const_fold2 => dataflow}/datalog/test.rs | 2 +- .../src/{const_fold2 => dataflow}/datalog/utils.rs | 0 .../src/{const_fold2 => dataflow}/partial_value.rs | 0 .../src/{const_fold2 => dataflow}/total_context.rs | 0 .../src/{const_fold2 => dataflow}/value_row.rs | 0 hugr-passes/src/lib.rs | 1 + 11 files changed, 23 insertions(+), 15 deletions(-) rename hugr-passes/src/const_fold2/{value_handle => }/context.rs (97%) create mode 100644 hugr-passes/src/dataflow.rs rename hugr-passes/src/{const_fold2 => dataflow}/datalog.rs (100%) rename hugr-passes/src/{const_fold2 => dataflow}/datalog/test.rs (99%) rename hugr-passes/src/{const_fold2 => dataflow}/datalog/utils.rs (100%) rename hugr-passes/src/{const_fold2 => dataflow}/partial_value.rs (100%) rename hugr-passes/src/{const_fold2 => dataflow}/total_context.rs (100%) rename hugr-passes/src/{const_fold2 => dataflow}/value_row.rs (100%) diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index b0ab62fdc..1fa3498e0 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -1,10 +1,7 @@ -mod datalog; -pub use datalog::Machine; +//! An (example) use of the [super::dataflow](dataflow-analysis framework) +//! to perform constant-folding. -pub mod partial_value; - -mod value_row; -pub use value_row::ValueRow; - -pub mod total_context; +// These are pub because this "example" is used for testing the framework. pub mod value_handle; +mod context; +pub use context::HugrValueContext; diff --git a/hugr-passes/src/const_fold2/value_handle/context.rs b/hugr-passes/src/const_fold2/context.rs similarity index 97% rename from hugr-passes/src/const_fold2/value_handle/context.rs rename to hugr-passes/src/const_fold2/context.rs index b24c57aa8..007fcfa92 100644 --- a/hugr-passes/src/const_fold2/value_handle/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -5,8 +5,8 @@ use std::sync::Arc; use hugr_core::ops::{OpType, Value}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; -use super::{ValueHandle, ValueKey}; -use crate::const_fold2::total_context::TotalContext; +use super::value_handle::{ValueHandle, ValueKey}; +use crate::dataflow::TotalContext; /// An implementation of [DFContext] with [ValueHandle] /// that just stores a Hugr (actually any [HugrView]), diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index 137699763..4bacef114 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -6,10 +6,7 @@ use hugr_core::ops::Value; use hugr_core::types::Type; use hugr_core::Node; -use super::partial_value::AbstractValue; - -mod context; -pub use context::HugrValueContext; +use crate::dataflow::AbstractValue; #[derive(Clone, Debug)] pub struct HashedConst { diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs new file mode 100644 index 000000000..ec00e7f9a --- /dev/null +++ b/hugr-passes/src/dataflow.rs @@ -0,0 +1,13 @@ +//! Dataflow analysis of Hugrs. + +mod datalog; +pub use datalog::Machine; + +mod partial_value; +pub use partial_value::{PartialValue, AbstractValue}; + +mod value_row; +pub use value_row::ValueRow; + +mod total_context; +pub use total_context::TotalContext; diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/dataflow/datalog.rs similarity index 100% rename from hugr-passes/src/const_fold2/datalog.rs rename to hugr-passes/src/dataflow/datalog.rs diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/dataflow/datalog/test.rs similarity index 99% rename from hugr-passes/src/const_fold2/datalog/test.rs rename to hugr-passes/src/dataflow/datalog/test.rs index 42c34c451..4f3a3b187 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/dataflow/datalog/test.rs @@ -1,4 +1,4 @@ -use crate::const_fold2::value_handle::HugrValueContext; +use crate::const_fold2::HugrValueContext; use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/dataflow/datalog/utils.rs similarity index 100% rename from hugr-passes/src/const_fold2/datalog/utils.rs rename to hugr-passes/src/dataflow/datalog/utils.rs diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs similarity index 100% rename from hugr-passes/src/const_fold2/partial_value.rs rename to hugr-passes/src/dataflow/partial_value.rs diff --git a/hugr-passes/src/const_fold2/total_context.rs b/hugr-passes/src/dataflow/total_context.rs similarity index 100% rename from hugr-passes/src/const_fold2/total_context.rs rename to hugr-passes/src/dataflow/total_context.rs diff --git a/hugr-passes/src/const_fold2/value_row.rs b/hugr-passes/src/dataflow/value_row.rs similarity index 100% rename from hugr-passes/src/const_fold2/value_row.rs rename to hugr-passes/src/dataflow/value_row.rs diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 8949d8bd4..9bf576a5e 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,6 +1,7 @@ //! Compilation passes acting on the HUGR program representation. pub mod const_fold; +pub mod dataflow; pub mod const_fold2; pub mod force_order; mod half_node; From 03ff1658a62805ecbd0e2654ccc38035924d41d9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 15:44:49 +0100 Subject: [PATCH 042/281] fmt --- hugr-passes/src/const_fold2.rs | 2 +- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/lib.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index 1fa3498e0..93b772d88 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -2,6 +2,6 @@ //! to perform constant-folding. // These are pub because this "example" is used for testing the framework. -pub mod value_handle; mod context; +pub mod value_handle; pub use context::HugrValueContext; diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index ec00e7f9a..8c0ad1d8c 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -4,7 +4,7 @@ mod datalog; pub use datalog::Machine; mod partial_value; -pub use partial_value::{PartialValue, AbstractValue}; +pub use partial_value::{AbstractValue, PartialValue}; mod value_row; pub use value_row::ValueRow; diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 9bf576a5e..0b73fcbb0 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,8 +1,8 @@ //! Compilation passes acting on the HUGR program representation. pub mod const_fold; -pub mod dataflow; pub mod const_fold2; +pub mod dataflow; pub mod force_order; mod half_node; pub mod lower; From 25ed1fbb4efd7759ae07bf85827e2ef02c65fa61 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 16:11:14 +0100 Subject: [PATCH 043/281] TailLoopTermination just examine whatever PartialValue's we have, remove most --- hugr-passes/src/dataflow/datalog.rs | 23 ++--- hugr-passes/src/dataflow/datalog/test.rs | 2 +- hugr-passes/src/dataflow/datalog/utils.rs | 120 +--------------------- 3 files changed, 11 insertions(+), 134 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 064aab91f..c1655d280 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,4 +1,3 @@ -use ascent::lattice::BoundedLattice; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use std::collections::HashMap; use std::hash::Hash; @@ -108,15 +107,6 @@ ascent::ascent! { if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); - lattice tail_loop_termination(C,Node,TailLoopTermination); - tail_loop_termination(c,tl_n,TailLoopTermination::bottom()) <-- - tail_loop_node(c,tl_n); - tail_loop_termination(c,tl_n,TailLoopTermination::from_control_value(v)) <-- - tail_loop_node(c,tl_n), - io_node(c,tl,out_n, IO::Output), - in_wire_value(c, out_n, IncomingPort::from(0), v); - - // Conditional relation conditional_node(C, Node); relation case_node(C,Node,usize, Node); @@ -221,11 +211,14 @@ impl> Machine { pub fn tail_loop_terminates(&self, hugr: impl HugrView, node: Node) -> TailLoopTermination { assert!(hugr.get_optype(node).is_tail_loop()); - self.0 - .tail_loop_termination - .iter() - .find_map(|(_, n, v)| (n == &node).then_some(*v)) - .unwrap() + let [_, out] = hugr.get_io(node).unwrap(); + TailLoopTermination::from_control_value( + self.0 + .in_wire_value + .iter() + .find_map(|(_, n, p, v)| (*n == out && p.index() == 0).then_some(v)) + .unwrap(), + ) } pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> bool { diff --git a/hugr-passes/src/dataflow/datalog/test.rs b/hugr-passes/src/dataflow/datalog/test.rs index 4f3a3b187..e9e61fb9e 100644 --- a/hugr-passes/src/dataflow/datalog/test.rs +++ b/hugr-passes/src/dataflow/datalog/test.rs @@ -125,7 +125,7 @@ fn test_tail_loop_always_iterates() { let o_r2 = machine.read_out_wire_partial_value(tl_o2).unwrap(); assert_eq!(o_r2, PartialValue::bottom().into()); assert_eq!( - TailLoopTermination::bottom(), + TailLoopTermination::Bottom, machine.tail_loop_terminates(&hugr, tail_loop.node()) ) } diff --git a/hugr-passes/src/dataflow/datalog/utils.rs b/hugr-passes/src/dataflow/datalog/utils.rs index ae486e280..117d58628 100644 --- a/hugr-passes/src/dataflow/datalog/utils.rs +++ b/hugr-passes/src/dataflow/datalog/utils.rs @@ -1,18 +1,8 @@ -// proptest-derive generates many of these warnings. -// https://github.com/rust-lang/rust/issues/120363 -// https://github.com/proptest-rs/proptest/issues/447 -#![cfg_attr(test, allow(non_local_definitions))] - -use std::cmp::Ordering; - use ascent::lattice::{BoundedLattice, Lattice}; use super::super::partial_value::{AbstractValue, PartialValue}; use hugr_core::{types::Signature, HugrView, IncomingPort, Node, OutgoingPort}; -#[cfg(test)] -use proptest_derive::Arbitrary; - impl Lattice for PartialValue { fn meet(self, other: Self) -> Self { self.meet(other) @@ -57,7 +47,6 @@ pub(super) fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator Option { - if self == other { - return Some(std::cmp::Ordering::Equal); - }; - match (self, other) { - (Self::Bottom, _) => Some(Ordering::Less), - (_, Self::Bottom) => Some(Ordering::Greater), - (Self::Top, _) => Some(Ordering::Greater), - (_, Self::Top) => Some(Ordering::Less), - _ => None, - } - } -} - -impl Lattice for TailLoopTermination { - fn meet(mut self, other: Self) -> Self { - self.meet_mut(other); - self - } - - fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self - } - - fn meet_mut(&mut self, other: Self) -> bool { - // let new_self = &mut self; - match (*self).partial_cmp(&other) { - Some(Ordering::Greater) => { - *self = other; - true - } - Some(_) => false, - _ => { - *self = Self::Bottom; - true - } - } - } - - fn join_mut(&mut self, other: Self) -> bool { - match (*self).partial_cmp(&other) { - Some(Ordering::Less) => { - *self = other; - true - } - Some(_) => false, - _ => { - *self = Self::Top; - true - } - } - } -} - -impl BoundedLattice for TailLoopTermination { - fn bottom() -> Self { - Self::Bottom - } - - fn top() -> Self { - Self::Top - } -} - -#[cfg(test)] -#[cfg_attr(test, allow(non_local_definitions))] -mod test { - use super::*; - use proptest::prelude::*; - - proptest! { - #[test] - fn bounded_lattice(v: TailLoopTermination) { - prop_assert!(v <= TailLoopTermination::top()); - prop_assert!(v >= TailLoopTermination::bottom()); - } - - #[test] - fn meet_join_self_noop(v1: TailLoopTermination) { - let mut subject = v1.clone(); - - assert_eq!(v1.clone(), v1.clone().join(v1.clone())); - assert!(!subject.join_mut(v1.clone())); - assert_eq!(subject, v1); - - assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); - assert!(!subject.meet_mut(v1.clone())); - assert_eq!(subject, v1); - } - - #[test] - fn lattice(v1: TailLoopTermination, v2: TailLoopTermination) { - let meet = v1.clone().meet(v2.clone()); - prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); - prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); - - let join = v1.clone().join(v2.clone()); - prop_assert!(join >= v1, "join not >=: {:#?}", &join); - prop_assert!(join >= v2, "join not >=: {:#?}", &join); + Self::Bottom } } } From 09911df3001979314ec1b6aade2356c9ed9e4c2f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 16:21:47 +0100 Subject: [PATCH 044/281] Drop non-(Bounded)Lattice impls of (join/meet)(_mut),top,bottom --- hugr-passes/src/dataflow/datalog.rs | 1 + hugr-passes/src/dataflow/datalog/test.rs | 1 + hugr-passes/src/dataflow/datalog/utils.rs | 30 -------- hugr-passes/src/dataflow/partial_value.rs | 92 +++++++++++------------ hugr-passes/src/dataflow/value_row.rs | 2 +- 5 files changed, 48 insertions(+), 78 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index c1655d280..d0983ee6c 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,3 +1,4 @@ +use ascent::lattice::BoundedLattice; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use std::collections::HashMap; use std::hash::Hash; diff --git a/hugr-passes/src/dataflow/datalog/test.rs b/hugr-passes/src/dataflow/datalog/test.rs index e9e61fb9e..c6c4eda4e 100644 --- a/hugr-passes/src/dataflow/datalog/test.rs +++ b/hugr-passes/src/dataflow/datalog/test.rs @@ -1,5 +1,6 @@ use crate::const_fold2::HugrValueContext; +use ascent::lattice::BoundedLattice; use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::prelude::{UnpackTuple, BOOL_T}, diff --git a/hugr-passes/src/dataflow/datalog/utils.rs b/hugr-passes/src/dataflow/datalog/utils.rs index 117d58628..4d3a056c6 100644 --- a/hugr-passes/src/dataflow/datalog/utils.rs +++ b/hugr-passes/src/dataflow/datalog/utils.rs @@ -1,36 +1,6 @@ -use ascent::lattice::{BoundedLattice, Lattice}; - use super::super::partial_value::{AbstractValue, PartialValue}; use hugr_core::{types::Signature, HugrView, IncomingPort, Node, OutgoingPort}; -impl Lattice for PartialValue { - fn meet(self, other: Self) -> Self { - self.meet(other) - } - - fn meet_mut(&mut self, other: Self) -> bool { - self.meet_mut(other) - } - - fn join(self, other: Self) -> Self { - self.join(other) - } - - fn join_mut(&mut self, other: Self) -> bool { - self.join_mut(other) - } -} - -impl BoundedLattice for PartialValue { - fn bottom() -> Self { - Self::bottom() - } - - fn top() -> Self { - Self::top() - } -} - pub(super) fn input_count(h: &impl HugrView, n: Node) -> usize { h.signature(n) .as_ref() diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 3c20ff965..670a7e38c 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,5 +1,7 @@ #![allow(missing_docs)] +use ascent::lattice::BoundedLattice; +use ascent::Lattice; use itertools::zip_eq; use std::cmp::Ordering; use std::collections::HashMap; @@ -176,15 +178,6 @@ impl From> for PartialValue { impl PartialValue { // const BOTTOM: Self = Self::Bottom; // const BOTTOM_REF: &'static Self = &Self::BOTTOM; - - // fn initialised(&self) -> bool { - // !self.is_top() - // } - - // fn is_top(&self) -> bool { - // self == &PartialValue::Top - // } - fn assert_invariants(&self) { match self { Self::PartialSum(ps) => { @@ -249,7 +242,42 @@ impl PartialValue { self } - pub fn join_mut(&mut self, other: Self) -> bool { + pub fn variant(tag: usize, values: impl IntoIterator) -> Self { + PartialSum::variant(tag, values).into() + } + + pub fn unit() -> Self { + Self::variant(0, []) + } + + pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { + let vals = match self { + PartialValue::Bottom => return None, + PartialValue::Value(v) => v + .as_sum() + .filter(|(variant, _)| tag == *variant)? + .1 + .map(PartialValue::Value) + .collect(), + PartialValue::PartialSum(ps) => ps.variant_values(tag, len)?, + PartialValue::Top => vec![PartialValue::Top; len], + }; + assert_eq!(vals.len(), len); + Some(vals) + } + + pub fn supports_tag(&self, tag: usize) -> bool { + match self { + PartialValue::Bottom => false, + PartialValue::Value(v) => v.as_sum().is_some_and(|(v, _)| v == tag), + PartialValue::PartialSum(ps) => ps.supports_tag(tag), + PartialValue::Top => true, + } + } +} + +impl Lattice for PartialValue { + fn join_mut(&mut self, other: Self) -> bool { // println!("join {self:?}\n{:?}", &other); let changed = match (&*self, other) { (Self::Top, _) => false, @@ -301,12 +329,12 @@ impl PartialValue { changed } - pub fn meet(mut self, other: Self) -> Self { + fn meet(mut self, other: Self) -> Self { self.meet_mut(other); self } - pub fn meet_mut(&mut self, other: Self) -> bool { + fn meet_mut(&mut self, other: Self) -> bool { let changed = match (&*self, other) { (Self::Bottom, _) => false, (_, other @ Self::Bottom) => { @@ -356,47 +384,16 @@ impl PartialValue { // } changed } +} - pub fn top() -> Self { +impl BoundedLattice for PartialValue { + fn top() -> Self { Self::Top } - pub fn bottom() -> Self { + fn bottom() -> Self { Self::Bottom } - - pub fn variant(tag: usize, values: impl IntoIterator) -> Self { - PartialSum::variant(tag, values).into() - } - - pub fn unit() -> Self { - Self::variant(0, []) - } - - pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { - let vals = match self { - PartialValue::Bottom => return None, - PartialValue::Value(v) => v - .as_sum() - .filter(|(variant, _)| tag == *variant)? - .1 - .map(PartialValue::Value) - .collect(), - PartialValue::PartialSum(ps) => ps.variant_values(tag, len)?, - PartialValue::Top => vec![PartialValue::Top; len], - }; - assert_eq!(vals.len(), len); - Some(vals) - } - - pub fn supports_tag(&self, tag: usize) -> bool { - match self { - PartialValue::Bottom => false, - PartialValue::Value(v) => v.as_sum().is_some_and(|(v, _)| v == tag), - PartialValue::PartialSum(ps) => ps.supports_tag(tag), - PartialValue::Top => true, - } - } } impl PartialOrd for PartialValue { @@ -420,6 +417,7 @@ impl PartialOrd for PartialValue { mod test { use std::sync::Arc; + use ascent::{lattice::BoundedLattice, Lattice}; use itertools::{zip_eq, Either, Itertools as _}; use proptest::prelude::*; diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs index 91b45052c..9f7b8bef7 100644 --- a/hugr-passes/src/dataflow/value_row.rs +++ b/hugr-passes/src/dataflow/value_row.rs @@ -5,7 +5,7 @@ use std::{ ops::{Index, IndexMut}, }; -use ascent::lattice::Lattice; +use ascent::lattice::{BoundedLattice, Lattice}; use itertools::zip_eq; use super::partial_value::{AbstractValue, PartialValue}; From 248fb0757c357a3168e1db66ea6cb6334d1896c0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 16:25:47 +0100 Subject: [PATCH 045/281] doc fixes, remove missing-docs for partial_value...how does it still work --- hugr-passes/src/const_fold2/context.rs | 8 +++++--- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/dataflow/datalog.rs | 2 +- hugr-passes/src/dataflow/partial_value.rs | 2 -- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs index 007fcfa92..f632427dd 100644 --- a/hugr-passes/src/const_fold2/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -8,9 +8,11 @@ use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; use super::value_handle::{ValueHandle, ValueKey}; use crate::dataflow::TotalContext; -/// An implementation of [DFContext] with [ValueHandle] -/// that just stores a Hugr (actually any [HugrView]), -/// (there is )no state for operation-interpretation). +/// A context ([DFContext]) for doing analysis with [ValueHandle]s. +/// Just stores a Hugr (actually any [HugrView]), +/// (there is )no state for operation-interpretation. +/// +/// [DFContext]: crate::dataflow::DFContext #[derive(Debug)] pub struct HugrValueContext(Arc); diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 8c0ad1d8c..3ceda2570 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -1,7 +1,7 @@ //! Dataflow analysis of Hugrs. mod datalog; -pub use datalog::Machine; +pub use datalog::{DFContext, Machine}; mod partial_value; pub use partial_value::{AbstractValue, PartialValue}; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index d0983ee6c..00cc59279 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -179,7 +179,7 @@ pub struct Machine>( /// Usage: /// 1. [Self::new()] /// 2. Zero or more [Self::propolutate_out_wires] with initial values -/// 3. Exactly one [Self::run_hugr] to do the analysis +/// 3. Exactly one [Self::run] to do the analysis /// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] impl> Machine { pub fn new() -> Self { diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 670a7e38c..ee5c6b8ba 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,5 +1,3 @@ -#![allow(missing_docs)] - use ascent::lattice::BoundedLattice; use ascent::Lattice; use itertools::zip_eq; From e2ad079141f9ec528c4f79991a8f95ddd3f16a83 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 16:46:33 +0100 Subject: [PATCH 046/281] fix all warnings (inc Machine::new() -> impl Default) --- hugr-passes/src/const_fold2/context.rs | 5 +--- hugr-passes/src/dataflow/datalog.rs | 29 +++++++++++-------- hugr-passes/src/dataflow/datalog/test.rs | 20 ++++++------- hugr-passes/src/dataflow/partial_value.rs | 34 +++++++---------------- hugr-passes/src/dataflow/total_context.rs | 6 ++-- 5 files changed, 42 insertions(+), 52 deletions(-) diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs index f632427dd..c18f5430b 100644 --- a/hugr-passes/src/const_fold2/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -83,10 +83,7 @@ impl TotalContext for HugrValueContext { )] } OpType::ExtensionOp(op) => { - let ins = ins - .into_iter() - .map(|(p, v)| (*p, v.clone())) - .collect::>(); + let ins = ins.iter().map(|(p, v)| (*p, v.clone())).collect::>(); op.constant_fold(&ins).map_or(Vec::new(), |outs| { outs.into_iter() .map(|(p, v)| (p, ValueHandle::new(ValueKey::Node(n), Arc::new(v)))) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 00cc59279..17552ad59 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,3 +1,9 @@ +#![allow( + clippy::clone_on_copy, + clippy::unused_enumerate_index, + clippy::collapsible_if +)] + use ascent::lattice::BoundedLattice; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use std::collections::HashMap; @@ -149,18 +155,17 @@ fn propagate_leaf_op( // Handle basics here. I guess (given the current interface) we could allow // DFContext to handle these but at the least we'd want these impls to be // easily available for reuse. - op if op.cast::().is_some() => Some(ValueRow::from_iter([PV::variant( - 0, - ins.into_iter().cloned(), - )])), + op if op.cast::().is_some() => { + Some(ValueRow::from_iter([PV::variant(0, ins.iter().cloned())])) + } op if op.cast::().is_some() => { - let [tup] = ins.into_iter().collect::>().try_into().unwrap(); + let [tup] = ins.iter().collect::>().try_into().unwrap(); tup.variant_values(0, utils::value_outputs(c.as_ref(), n).count()) .map(ValueRow::from_iter) } OpType::Tag(t) => Some(ValueRow::from_iter([PV::variant( t.tag, - ins.into_iter().cloned(), + ins.iter().cloned(), )])), OpType::Input(_) | OpType::Output(_) => None, // handled by parent // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, @@ -170,22 +175,24 @@ fn propagate_leaf_op( } } -// TODO This should probably be called 'Analyser' or something pub struct Machine>( AscentProgram, Option>>, ); +/// derived-Default requires the context to be Defaultable, which is unnecessary +impl> Default for Machine { + fn default() -> Self { + Self(Default::default(), None) + } +} + /// Usage: /// 1. [Self::new()] /// 2. Zero or more [Self::propolutate_out_wires] with initial values /// 3. Exactly one [Self::run] to do the analysis /// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] impl> Machine { - pub fn new() -> Self { - Self(Default::default(), None) - } - pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator)>) { assert!(self.1.is_none()); self.0 diff --git a/hugr-passes/src/dataflow/datalog/test.rs b/hugr-passes/src/dataflow/datalog/test.rs index c6c4eda4e..d531e0ffd 100644 --- a/hugr-passes/src/dataflow/datalog/test.rs +++ b/hugr-passes/src/dataflow/datalog/test.rs @@ -22,7 +22,7 @@ fn test_make_tuple() { let v3 = builder.make_tuple([v1, v2]).unwrap(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::new(); + let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); let x = machine.read_out_wire_value(&hugr, v3).unwrap(); @@ -41,7 +41,7 @@ fn test_unpack_tuple() { .outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::new(); + let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); let o1_r = machine.read_out_wire_value(&hugr, o1).unwrap(); @@ -60,7 +60,7 @@ fn test_unpack_const() { .outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::new(); + let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); let o_r = machine.read_out_wire_value(&hugr, o).unwrap(); @@ -86,7 +86,7 @@ fn test_tail_loop_never_iterates() { let [tl_o] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::new(); + let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); @@ -118,7 +118,7 @@ fn test_tail_loop_always_iterates() { let [tl_o1, tl_o2] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::new(); + let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); let o_r1 = machine.read_out_wire_partial_value(tl_o1).unwrap(); @@ -168,15 +168,15 @@ fn test_tail_loop_iterates_twice() { // we should be able to propagate their values let [o_w1, o_w2] = tail_loop.outputs_arr(); - let mut machine = Machine::new(); + let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); // TODO these hould be the propagated values for now they will bt join(true,false) - let o_r1 = machine.read_out_wire_partial_value(o_w1).unwrap(); + let _ = machine.read_out_wire_partial_value(o_w1).unwrap(); // assert_eq!(o_r1, PartialValue::top()); - let o_r2 = machine.read_out_wire_partial_value(o_w2).unwrap(); + let _ = machine.read_out_wire_partial_value(o_w2).unwrap(); // assert_eq!(o_r2, Value::true_val()); assert_eq!( TailLoopTermination::Top, @@ -212,7 +212,7 @@ fn conditional() { let case2 = case2_b.finish_with_outputs([false_w, c2a]).unwrap(); let case3_b = cond_builder.case_builder(2).unwrap(); - let [c3_1, c3_2] = case3_b.input_wires_arr(); + let [c3_1, _c3_2] = case3_b.input_wires_arr(); let case3 = case3_b.finish_with_outputs([c3_1, false_w]).unwrap(); let cond = cond_builder.finish_sub_container().unwrap(); @@ -221,7 +221,7 @@ fn conditional() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::new(); + let mut machine = Machine::default(); let arg_pv = PartialValue::variant(1, []).join(PartialValue::variant(2, [PartialValue::variant(0, [])])); machine.propolutate_out_wires([(arg_w, arg_pv.into())]); diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index ee5c6b8ba..f1f05f877 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -122,7 +122,7 @@ impl PartialOrd for PartialSum { return None; } for (k, lhs) in &self.0 { - let Some(rhs) = other.0.get(&k) else { + let Some(rhs) = other.0.get(k) else { unreachable!() }; match lhs.partial_cmp(rhs) { @@ -192,8 +192,10 @@ impl PartialValue { self.assert_invariants(); match &*self { Self::Top => return false, - Self::Value(v) if v == &vh => return false, Self::Value(v) => { + if v == &vh { + return false; + }; *self = Self::Top; } Self::PartialSum(_) => match vh.into() { @@ -277,7 +279,7 @@ impl PartialValue { impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { // println!("join {self:?}\n{:?}", &other); - let changed = match (&*self, other) { + match (&*self, other) { (Self::Top, _) => false, (_, other @ Self::Top) => { *self = other; @@ -316,15 +318,7 @@ impl Lattice for PartialValue { self.join_mut_value_handle(old_self) } (_, Self::Value(h)) => self.join_mut_value_handle(h), - // (new_self, _) => { - // **new_self = Self::Top; - // false - // } - }; - // if changed { - // println!("join new self: {:?}", s); - // } - changed + } } fn meet(mut self, other: Self) -> Self { @@ -333,7 +327,7 @@ impl Lattice for PartialValue { } fn meet_mut(&mut self, other: Self) -> bool { - let changed = match (&*self, other) { + match (&*self, other) { (Self::Bottom, _) => false, (_, other @ Self::Bottom) => { *self = other; @@ -372,15 +366,7 @@ impl Lattice for PartialValue { self.meet_mut_value_handle(old_self) } (Self::PartialSum(_), Self::Value(h)) => self.meet_mut_value_handle(h), - // (new_self, _) => { - // **new_self = Self::Bottom; - // false - // } - }; - // if changed { - // println!("join new self: {:?}", s); - // } - changed + } } } @@ -519,8 +505,7 @@ mod test { } impl TestSumType { - const UNIT: TestSumLeafType = TestSumLeafType::Unit; - + #[allow(unused)] // ALAN ? fn leaf(v: Type) -> Self { TestSumType::Leaf(TestSumLeafType::Int(v)) } @@ -543,6 +528,7 @@ mod test { } } + #[allow(unused)] // ALAN ? fn is_leaf(&self) -> bool { self.depth() == 0 } diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index 8caafda52..882175af0 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -74,11 +74,11 @@ impl PartialSum { if v.len() != r.len() { return Err(self); } - match zip_eq(v.into_iter(), r.into_iter()) + match zip_eq(v, r.iter()) .map(|(v, t)| v.clone().try_into_value(t)) .collect::, _>>() { - Ok(vs) => V2::try_new_sum(*k, vs, &st).map_err(|_| self), + Ok(vs) => V2::try_new_sum(*k, vs, st).map_err(|_| self), Err(_) => Err(self), } } @@ -90,7 +90,7 @@ impl> DFContext for T { let sig = op.dataflow_signature()?; let known_ins = sig .input_types() - .into_iter() + .iter() .enumerate() .zip(ins.iter()) .filter_map(|((i, ty), pv)| { From 95e2dd952c6aab27c180eef90db003e19933caab Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 16:59:33 +0100 Subject: [PATCH 047/281] distribute utils.rs -> machine.rs --- hugr-passes/src/dataflow.rs | 10 +- hugr-passes/src/dataflow/datalog.rs | 115 ++++------------------ hugr-passes/src/dataflow/datalog/test.rs | 7 +- hugr-passes/src/dataflow/datalog/utils.rs | 37 ------- hugr-passes/src/dataflow/machine.rs | 110 +++++++++++++++++++++ hugr-passes/src/dataflow/total_context.rs | 2 +- 6 files changed, 145 insertions(+), 136 deletions(-) delete mode 100644 hugr-passes/src/dataflow/datalog/utils.rs create mode 100644 hugr-passes/src/dataflow/machine.rs diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 3ceda2570..452d070be 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -1,7 +1,8 @@ //! Dataflow analysis of Hugrs. mod datalog; -pub use datalog::{DFContext, Machine}; +mod machine; +pub use machine::Machine; mod partial_value; pub use partial_value::{AbstractValue, PartialValue}; @@ -11,3 +12,10 @@ pub use value_row::ValueRow; mod total_context; pub use total_context::TotalContext; + +use hugr_core::{Hugr, Node}; +use std::hash::Hash; + +pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { + fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option>; +} diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 17552ad59..411b1dbcc 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,3 +1,6 @@ +//! [ascent] datalog implementation of analysis. +//! Since ascent-(macro-)generated code generates a bunch of warnings, +//! keep code in here to a minimum. #![allow( clippy::clone_on_copy, clippy::unused_enumerate_index, @@ -6,25 +9,16 @@ use ascent::lattice::BoundedLattice; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use std::collections::HashMap; +use hugr_core::types::Signature; use std::hash::Hash; -use hugr_core::ops::{OpType, Value}; -use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; +use hugr_core::ops::OpType; +use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; -mod utils; - -// TODO separate this into its own analysis? -use utils::TailLoopTermination; - -use super::partial_value::AbstractValue; use super::value_row::ValueRow; +use super::{AbstractValue, DFContext}; type PV = super::partial_value::PartialValue; -pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { - fn interpret_leaf_op(&self, node: Node, ins: &[PV]) -> Option>; -} - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum IO { Input, @@ -32,7 +26,7 @@ pub enum IO { } ascent::ascent! { - struct AscentProgram>; + pub(super) struct AscentProgram>; relation context(C); relation out_wire_value_proto(Node, OutgoingPort, PV); @@ -47,9 +41,9 @@ ascent::ascent! { node(c, n) <-- context(c), for n in c.nodes(); - in_wire(c, n,p) <-- node(c, n), for p in utils::value_inputs(c.as_ref(), *n); + in_wire(c, n,p) <-- node(c, n), for p in value_inputs(c.as_ref(), *n); - out_wire(c, n,p) <-- node(c, n), for p in utils::value_outputs(c.as_ref(), *n); + out_wire(c, n,p) <-- node(c, n), for p in value_outputs(c.as_ref(), *n); parent_of_node(c, parent, child) <-- node(c, child), if let Some(parent) = c.get_parent(*child); @@ -68,7 +62,7 @@ ascent::ascent! { out_wire_value(c, m, op, v); - node_in_value_row(c, n, ValueRow::new(utils::input_count(c.as_ref(), *n))) <-- node(c, n); + node_in_value_row(c, n, ValueRow::new(input_count(c.as_ref(), *n))) <-- node(c, n); node_in_value_row(c, n, ValueRow::single_known(c.signature(*n).unwrap().input.len(), p.index(), v.clone())) <-- in_wire_value(c, n, p, v); out_wire_value(c, n, p, v) <-- @@ -160,7 +154,7 @@ fn propagate_leaf_op( } op if op.cast::().is_some() => { let [tup] = ins.iter().collect::>().try_into().unwrap(); - tup.variant_values(0, utils::value_outputs(c.as_ref(), n).count()) + tup.variant_values(0, value_outputs(c.as_ref(), n).count()) .map(ValueRow::from_iter) } OpType::Tag(t) => Some(ValueRow::from_iter([PV::variant( @@ -175,86 +169,19 @@ fn propagate_leaf_op( } } -pub struct Machine>( - AscentProgram, - Option>>, -); - -/// derived-Default requires the context to be Defaultable, which is unnecessary -impl> Default for Machine { - fn default() -> Self { - Self(Default::default(), None) - } +fn input_count(h: &impl HugrView, n: Node) -> usize { + h.signature(n) + .as_ref() + .map(Signature::input_count) + .unwrap_or(0) } -/// Usage: -/// 1. [Self::new()] -/// 2. Zero or more [Self::propolutate_out_wires] with initial values -/// 3. Exactly one [Self::run] to do the analysis -/// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] -impl> Machine { - pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator)>) { - assert!(self.1.is_none()); - self.0 - .out_wire_value_proto - .extend(wires.into_iter().map(|(w, v)| (w.node(), w.source(), v))); - } - - pub fn run(&mut self, context: C) { - assert!(self.1.is_none()); - self.0.context.push((context,)); - self.0.run(); - self.1 = Some( - self.0 - .out_wire_value - .iter() - .map(|(_, n, p, v)| (Wire::new(*n, *p), v.clone())) - .collect(), - ) - } - - pub fn read_out_wire_partial_value(&self, w: Wire) -> Option> { - self.1.as_ref().unwrap().get(&w).cloned() - } - - pub fn tail_loop_terminates(&self, hugr: impl HugrView, node: Node) -> TailLoopTermination { - assert!(hugr.get_optype(node).is_tail_loop()); - let [_, out] = hugr.get_io(node).unwrap(); - TailLoopTermination::from_control_value( - self.0 - .in_wire_value - .iter() - .find_map(|(_, n, p, v)| (*n == out && p.index() == 0).then_some(v)) - .unwrap(), - ) - } - - pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> bool { - assert!(hugr.get_optype(case).is_case()); - let cond = hugr.get_parent(case).unwrap(); - assert!(hugr.get_optype(cond).is_conditional()); - self.0 - .case_reachable - .iter() - .find_map(|(_, cond2, case2, i)| (&cond == cond2 && &case == case2).then_some(*i)) - .unwrap() - } +fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { + h.in_value_types(n).map(|x| x.0) } -impl> Machine -where - Value: From, -{ - pub fn read_out_wire_value(&self, hugr: impl HugrView, w: Wire) -> Option { - // dbg!(&w); - let pv = self.read_out_wire_partial_value(w)?; - // dbg!(&pv); - let (_, typ) = hugr - .out_value_types(w.node()) - .find(|(p, _)| *p == w.source()) - .unwrap(); - pv.try_into_value(&typ).ok() - } +fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { + h.out_value_types(n).map(|x| x.0) } #[cfg(test)] diff --git a/hugr-passes/src/dataflow/datalog/test.rs b/hugr-passes/src/dataflow/datalog/test.rs index d531e0ffd..68708084a 100644 --- a/hugr-passes/src/dataflow/datalog/test.rs +++ b/hugr-passes/src/dataflow/datalog/test.rs @@ -1,4 +1,7 @@ -use crate::const_fold2::HugrValueContext; +use crate::{ + const_fold2::HugrValueContext, + dataflow::{machine::TailLoopTermination, Machine}, +}; use ascent::lattice::BoundedLattice; use hugr_core::{ @@ -12,8 +15,6 @@ use hugr_core::{ use super::super::partial_value::PartialValue; -use super::*; - #[test] fn test_make_tuple() { let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); diff --git a/hugr-passes/src/dataflow/datalog/utils.rs b/hugr-passes/src/dataflow/datalog/utils.rs deleted file mode 100644 index 4d3a056c6..000000000 --- a/hugr-passes/src/dataflow/datalog/utils.rs +++ /dev/null @@ -1,37 +0,0 @@ -use super::super::partial_value::{AbstractValue, PartialValue}; -use hugr_core::{types::Signature, HugrView, IncomingPort, Node, OutgoingPort}; - -pub(super) fn input_count(h: &impl HugrView, n: Node) -> usize { - h.signature(n) - .as_ref() - .map(Signature::input_count) - .unwrap_or(0) -} - -pub(super) fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { - h.in_value_types(n).map(|x| x.0) -} - -pub(super) fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { - h.out_value_types(n).map(|x| x.0) -} - -#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] -pub enum TailLoopTermination { - Bottom, - ExactlyZeroContinues, - Top, -} - -impl TailLoopTermination { - pub fn from_control_value(v: &PartialValue) -> Self { - let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); - if may_break && !may_continue { - Self::ExactlyZeroContinues - } else if may_break && may_continue { - Self::Top - } else { - Self::Bottom - } - } -} diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs new file mode 100644 index 000000000..b6736a336 --- /dev/null +++ b/hugr-passes/src/dataflow/machine.rs @@ -0,0 +1,110 @@ +use std::collections::HashMap; + +use hugr_core::{ops::Value, HugrView, Node, PortIndex, Wire}; + +use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; + +pub struct Machine>( + AscentProgram, + Option>>, +); + +/// derived-Default requires the context to be Defaultable, which is unnecessary +impl> Default for Machine { + fn default() -> Self { + Self(Default::default(), None) + } +} + +/// Usage: +/// 1. [Self::new()] +/// 2. Zero or more [Self::propolutate_out_wires] with initial values +/// 3. Exactly one [Self::run] to do the analysis +/// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] +impl> Machine { + pub fn propolutate_out_wires( + &mut self, + wires: impl IntoIterator)>, + ) { + assert!(self.1.is_none()); + self.0 + .out_wire_value_proto + .extend(wires.into_iter().map(|(w, v)| (w.node(), w.source(), v))); + } + + pub fn run(&mut self, context: C) { + assert!(self.1.is_none()); + self.0.context.push((context,)); + self.0.run(); + self.1 = Some( + self.0 + .out_wire_value + .iter() + .map(|(_, n, p, v)| (Wire::new(*n, *p), v.clone())) + .collect(), + ) + } + + pub fn read_out_wire_partial_value(&self, w: Wire) -> Option> { + self.1.as_ref().unwrap().get(&w).cloned() + } + + pub fn tail_loop_terminates(&self, hugr: impl HugrView, node: Node) -> TailLoopTermination { + assert!(hugr.get_optype(node).is_tail_loop()); + let [_, out] = hugr.get_io(node).unwrap(); + TailLoopTermination::from_control_value( + self.0 + .in_wire_value + .iter() + .find_map(|(_, n, p, v)| (*n == out && p.index() == 0).then_some(v)) + .unwrap(), + ) + } + + pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> bool { + assert!(hugr.get_optype(case).is_case()); + let cond = hugr.get_parent(case).unwrap(); + assert!(hugr.get_optype(cond).is_conditional()); + self.0 + .case_reachable + .iter() + .find_map(|(_, cond2, case2, i)| (&cond == cond2 && &case == case2).then_some(*i)) + .unwrap() + } +} + +impl> Machine +where + Value: From, +{ + pub fn read_out_wire_value(&self, hugr: impl HugrView, w: Wire) -> Option { + // dbg!(&w); + let pv = self.read_out_wire_partial_value(w)?; + // dbg!(&pv); + let (_, typ) = hugr + .out_value_types(w.node()) + .find(|(p, _)| *p == w.source()) + .unwrap(); + pv.try_into_value(&typ).ok() + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] +pub enum TailLoopTermination { + Bottom, + ExactlyZeroContinues, + Top, +} + +impl TailLoopTermination { + pub fn from_control_value(v: &PartialValue) -> Self { + let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); + if may_break && !may_continue { + Self::ExactlyZeroContinues + } else if may_break && may_continue { + Self::Top + } else { + Self::Bottom + } + } +} diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index 882175af0..89067105b 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -5,8 +5,8 @@ use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; use itertools::{zip_eq, Itertools}; -use super::datalog::DFContext; use super::partial_value::{AbstractValue, PartialSum, PartialValue}; +use super::DFContext; use super::ValueRow; pub trait FromSum: Sized { From 7f1e122587d4832b5e263d35d99ff205748d7ef1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 17:03:14 +0100 Subject: [PATCH 048/281] Move dataflow{/datalog=>}/test.rs --- hugr-passes/src/dataflow.rs | 3 +++ hugr-passes/src/dataflow/datalog.rs | 3 --- hugr-passes/src/dataflow/{datalog => }/test.rs | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) rename hugr-passes/src/dataflow/{datalog => }/test.rs (99%) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 452d070be..827489144 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -19,3 +19,6 @@ use std::hash::Hash; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option>; } + +#[cfg(test)] +mod test; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 411b1dbcc..06f49150f 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -183,6 +183,3 @@ fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator impl Iterator + '_ { h.out_value_types(n).map(|x| x.0) } - -#[cfg(test)] -mod test; diff --git a/hugr-passes/src/dataflow/datalog/test.rs b/hugr-passes/src/dataflow/test.rs similarity index 99% rename from hugr-passes/src/dataflow/datalog/test.rs rename to hugr-passes/src/dataflow/test.rs index 68708084a..738e64073 100644 --- a/hugr-passes/src/dataflow/datalog/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -13,7 +13,7 @@ use hugr_core::{ types::{Signature, SumType, Type, TypeRow}, }; -use super::super::partial_value::PartialValue; +use super::partial_value::PartialValue; #[test] fn test_make_tuple() { From faff5560dae64f17f17f0c748d3121f4a4c6497c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 17:09:52 +0100 Subject: [PATCH 049/281] and more warnings --- hugr-passes/src/const_fold2/value_handle.rs | 4 ++-- hugr-passes/src/dataflow/partial_value.rs | 15 ++++++--------- hugr-passes/src/dataflow/test.rs | 6 +++--- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index 4bacef114..f24a4b734 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -157,7 +157,7 @@ mod test { assert_ne!(k1, k3); assert_eq!(ValueKey::from(n), ValueKey::from(n)); - let f = ConstF64::new(3.141); + let f = ConstF64::new(std::f64::consts::PI); assert_eq!(ValueKey::new(n, f.clone()), ValueKey::from(n)); assert_ne!(ValueKey::new(n, f.clone()), ValueKey::new(n2, f)); // Node taken into account @@ -182,7 +182,7 @@ mod test { fn value_key_list() { let v1 = ConstInt::new_u(3, 3).unwrap(); let v2 = ConstInt::new_u(4, 3).unwrap(); - let v3 = ConstF64::new(3.141); + let v3 = ConstF64::new(std::f64::consts::PI); let n = Node::from(portgraph::NodeIndex::new(0)); let n2: Node = portgraph::NodeIndex::new(1).into(); diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index f1f05f877..d792154c2 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -432,16 +432,13 @@ mod test { impl TestSumLeafType { fn assert_invariants(&self) { - match self { - Self::Int(t) => { - if let TypeEnum::Extension(ct) = t.as_type_enum() { - assert_eq!("int", ct.name()); - assert_eq!(&int_types::EXTENSION_ID, ct.extension()); - } else { - panic!("Expected int type, got {:#?}", t); - } + if let Self::Int(t) = self { + if let TypeEnum::Extension(ct) = t.as_type_enum() { + assert_eq!("int", ct.name()); + assert_eq!(&int_types::EXTENSION_ID, ct.extension()); + } else { + panic!("Expected int type, got {:#?}", t); } - _ => (), } } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 738e64073..fd683eaa8 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -123,9 +123,9 @@ fn test_tail_loop_always_iterates() { machine.run(HugrValueContext::new(&hugr)); let o_r1 = machine.read_out_wire_partial_value(tl_o1).unwrap(); - assert_eq!(o_r1, PartialValue::bottom().into()); + assert_eq!(o_r1, PartialValue::bottom()); let o_r2 = machine.read_out_wire_partial_value(tl_o2).unwrap(); - assert_eq!(o_r2, PartialValue::bottom().into()); + assert_eq!(o_r2, PartialValue::bottom()); assert_eq!( TailLoopTermination::Bottom, machine.tail_loop_terminates(&hugr, tail_loop.node()) @@ -225,7 +225,7 @@ fn conditional() { let mut machine = Machine::default(); let arg_pv = PartialValue::variant(1, []).join(PartialValue::variant(2, [PartialValue::variant(0, [])])); - machine.propolutate_out_wires([(arg_w, arg_pv.into())]); + machine.propolutate_out_wires([(arg_w, arg_pv)]); machine.run(HugrValueContext::new(&hugr)); let cond_r1 = machine.read_out_wire_value(&hugr, cond_o1).unwrap(); From 401354dacdbcede64246138914d2894b9e58a7a5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 17:14:45 +0100 Subject: [PATCH 050/281] fix extension tests --- hugr-passes/src/dataflow/test.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index fd683eaa8..749de5dd3 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -5,7 +5,7 @@ use crate::{ use ascent::lattice::BoundedLattice; use hugr_core::{ - builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, + builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::prelude::{UnpackTuple, BOOL_T}, extension::{ExtensionSet, EMPTY_REG}, ops::{handle::NodeHandle, OpTrait, Value}, @@ -17,7 +17,7 @@ use super::partial_value::PartialValue; #[test] fn test_make_tuple() { - let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); let v1 = builder.add_load_value(Value::false_val()); let v2 = builder.add_load_value(Value::true_val()); let v3 = builder.make_tuple([v1, v2]).unwrap(); @@ -32,7 +32,7 @@ fn test_make_tuple() { #[test] fn test_unpack_tuple() { - let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); let v1 = builder.add_load_value(Value::false_val()); let v2 = builder.add_load_value(Value::true_val()); let v3 = builder.make_tuple([v1, v2]).unwrap(); @@ -53,7 +53,7 @@ fn test_unpack_tuple() { #[test] fn test_unpack_const() { - let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); let v1 = builder.add_load_value(Value::tuple([Value::true_val()])); let [o] = builder .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T]), [v1]) From c468387e5df3cfbb7e7c7000c4d89eafa76d8ab1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 17:19:11 +0100 Subject: [PATCH 051/281] Fix doclink, fix DefaultHasher pre-1.76 --- hugr-passes/src/const_fold2/value_handle.rs | 3 ++- hugr-passes/src/dataflow/machine.rs | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index f24a4b734..7b6e26106 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -1,4 +1,5 @@ -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76. +use std::hash::{Hash, Hasher}; use std::sync::Arc; use hugr_core::ops::constant::{CustomConst, Sum}; diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index b6736a336..6fe79208b 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -17,7 +17,7 @@ impl> Default for Machine { } /// Usage: -/// 1. [Self::new()] +/// 1. Get a new instance via [Self::default()] /// 2. Zero or more [Self::propolutate_out_wires] with initial values /// 3. Exactly one [Self::run] to do the analysis /// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] From 88db5b18a7cb6b33be8e3229361c7188e3b856d5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 11 Sep 2024 16:47:50 +0100 Subject: [PATCH 052/281] comment conditional test --- hugr-passes/src/dataflow/test.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 749de5dd3..d8d8698af 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -232,7 +232,7 @@ fn conditional() { assert_eq!(cond_r1, Value::false_val()); assert!(machine.read_out_wire_value(&hugr, cond_o2).is_none()); - assert!(!machine.case_reachable(&hugr, case1.node())); + assert!(!machine.case_reachable(&hugr, case1.node())); // arg_pv is variant 1 or 2 only assert!(machine.case_reachable(&hugr, case2.node())); assert!(machine.case_reachable(&hugr, case3.node())); } From 738b61b168933bfc097eaa286f6d7eafa88cb360 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 11 Sep 2024 16:48:19 +0100 Subject: [PATCH 053/281] Clarify (TODO untested) branches of join_mut --- hugr-passes/src/dataflow/partial_value.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index d792154c2..4a00de876 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -310,14 +310,14 @@ impl Lattice for PartialValue { } } } - (Self::Value(_), mut other) => { + (Self::Value(_), mut other@Self::PartialSum(_)) => { std::mem::swap(self, &mut other); let Self::Value(old_self) = other else { unreachable!() }; self.join_mut_value_handle(old_self) } - (_, Self::Value(h)) => self.join_mut_value_handle(h), + (Self::PartialSum(_), Self::Value(h)) => self.join_mut_value_handle(h), } } From 8f9c1ed852b22c25cb36fac01c1c484bd29206fb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 11 Sep 2024 17:29:17 +0100 Subject: [PATCH 054/281] Exploit invariant PartialValue::Value is not a sum (even single known variant) --- hugr-passes/src/dataflow/partial_value.rs | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 4a00de876..bcae3fc79 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -151,6 +151,8 @@ impl Hash for PartialSum { } } +/// We really must prevent people from constructing PartialValue::Value of +/// any `value` where `value.as_sum().is_some()`` #[derive(PartialEq, Clone, Eq, Hash, Debug)] pub enum PartialValue { Bottom, @@ -253,12 +255,10 @@ impl PartialValue { pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match self { PartialValue::Bottom => return None, - PartialValue::Value(v) => v - .as_sum() - .filter(|(variant, _)| tag == *variant)? - .1 - .map(PartialValue::Value) - .collect(), + PartialValue::Value(v) => { + assert!(v.as_sum().is_none()); + return None; + } PartialValue::PartialSum(ps) => ps.variant_values(tag, len)?, PartialValue::Top => vec![PartialValue::Top; len], }; @@ -269,7 +269,10 @@ impl PartialValue { pub fn supports_tag(&self, tag: usize) -> bool { match self { PartialValue::Bottom => false, - PartialValue::Value(v) => v.as_sum().is_some_and(|(v, _)| v == tag), + PartialValue::Value(v) => { + assert!(v.as_sum().is_none()); + false + } PartialValue::PartialSum(ps) => ps.supports_tag(tag), PartialValue::Top => true, } @@ -310,7 +313,7 @@ impl Lattice for PartialValue { } } } - (Self::Value(_), mut other@Self::PartialSum(_)) => { + (Self::Value(_), mut other @ Self::PartialSum(_)) => { std::mem::swap(self, &mut other); let Self::Value(old_self) = other else { unreachable!() From a71ba9743e81f4d496a076b3df9681d3a6ef78b6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 11 Sep 2024 17:31:06 +0100 Subject: [PATCH 055/281] Exploit invariant more, RIP join_mut_value_handle --- hugr-passes/src/dataflow/partial_value.rs | 35 ++++------------------- 1 file changed, 5 insertions(+), 30 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index bcae3fc79..ce1b8cd9a 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -190,29 +190,6 @@ impl PartialValue { } } - fn join_mut_value_handle(&mut self, vh: V) -> bool { - self.assert_invariants(); - match &*self { - Self::Top => return false, - Self::Value(v) => { - if v == &vh { - return false; - }; - *self = Self::Top; - } - Self::PartialSum(_) => match vh.into() { - Self::Value(_) => { - *self = Self::Top; - } - other => return self.join_mut(other), - }, - Self::Bottom => { - *self = vh.into(); - } - }; - true - } - fn meet_mut_value_handle(&mut self, vh: V) -> bool { self.assert_invariants(); match &*self { @@ -313,14 +290,12 @@ impl Lattice for PartialValue { } } } - (Self::Value(_), mut other @ Self::PartialSum(_)) => { - std::mem::swap(self, &mut other); - let Self::Value(old_self) = other else { - unreachable!() - }; - self.join_mut_value_handle(old_self) + (Self::Value(ref v), Self::PartialSum(_)) + | (Self::PartialSum(_), Self::Value(ref v)) => { + assert!(v.as_sum().is_none()); + *self = Self::Top; + true } - (Self::PartialSum(_), Self::Value(h)) => self.join_mut_value_handle(h), } } From 05280a8efd30442caf4d020ef0371dbb14ab1ca1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 11 Sep 2024 17:34:29 +0100 Subject: [PATCH 056/281] By similar logic, RIP meet_mut_value_handle; assert_(in)variants now unused --- hugr-passes/src/dataflow/partial_value.rs | 38 +++-------------------- 1 file changed, 5 insertions(+), 33 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index ce1b8cd9a..d9baf2c63 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -190,32 +190,6 @@ impl PartialValue { } } - fn meet_mut_value_handle(&mut self, vh: V) -> bool { - self.assert_invariants(); - match &*self { - Self::Bottom => false, - Self::Value(v) => { - if v == &vh { - false - } else { - *self = Self::Bottom; - true - } - } - Self::PartialSum(_) => match vh.into() { - Self::Value(_) => { - *self = Self::Bottom; - true - } - other => self.meet_mut(other), - }, - Self::Top => { - *self = vh.into(); - true - } - } - } - pub fn join(mut self, other: Self) -> Self { self.join_mut(other); self @@ -336,14 +310,12 @@ impl Lattice for PartialValue { } } } - (Self::Value(_), mut other @ Self::PartialSum(_)) => { - std::mem::swap(self, &mut other); - let Self::Value(old_self) = other else { - unreachable!() - }; - self.meet_mut_value_handle(old_self) + (Self::Value(ref v), Self::PartialSum(_)) + | (Self::PartialSum(_), Self::Value(ref v)) => { + assert!(v.as_sum().is_none()); + *self = Self::Bottom; + true } - (Self::PartialSum(_), Self::Value(h)) => self.meet_mut_value_handle(h), } } } From 96b0856f41bec7ea5c0680bfc8c15993c885b0eb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 13 Sep 2024 10:50:36 +0100 Subject: [PATCH 057/281] Rename TestSum(,Leaf)Type::assert_{invariants=>valid} --- hugr-passes/src/dataflow/partial_value.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index d9baf2c63..319371a5d 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -381,7 +381,7 @@ mod test { } impl TestSumLeafType { - fn assert_invariants(&self) { + fn assert_valid(&self) { if let Self::Int(t) = self { if let TypeEnum::Extension(ct) = t.as_type_enum() { assert_eq!("int", ct.name()); @@ -480,17 +480,17 @@ mod test { self.depth() == 0 } - fn assert_invariants(&self) { + fn assert_valid(&self) { match self { TestSumType::Branch(d, sop) => { assert!(!sop.is_empty(), "No variants"); for v in sop.iter().flat_map(|x| x.iter()) { assert!(v.depth() < *d); - v.assert_invariants(); + v.assert_valid(); } } TestSumType::Leaf(l) => { - l.assert_invariants(); + l.assert_valid(); } } } @@ -601,7 +601,7 @@ mod test { proptest! { #[test] fn unary_sum_type_valid(ust: TestSumType) { - ust.assert_invariants(); + ust.assert_valid(); } } From f21e278dff9985dd3759c54a271aaa8861d3c604 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 13 Sep 2024 10:55:35 +0100 Subject: [PATCH 058/281] Rename assert_(=>in)variants (i.e. to match); call in (join/meet)_mut; fix! --- hugr-passes/src/dataflow/partial_value.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 319371a5d..3e1b654ed 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -31,7 +31,7 @@ impl PartialSum { } impl PartialSum { - fn assert_variants(&self) { + fn assert_invariants(&self) { assert_ne!(self.num_variants(), 0); for pv in self.0.values().flat_map(|x| x.iter()) { pv.assert_invariants(); @@ -164,7 +164,7 @@ pub enum PartialValue { impl From for PartialValue { fn from(v: V) -> Self { v.as_sum() - .map(|(tag, values)| Self::variant(tag, values.map(Self::Value))) + .map(|(tag, values)| Self::variant(tag, values.map(Self::from))) .unwrap_or(Self::Value(v)) } } @@ -181,7 +181,7 @@ impl PartialValue { fn assert_invariants(&self) { match self { Self::PartialSum(ps) => { - ps.assert_variants(); + ps.assert_invariants(); } Self::Value(v) => { assert!(v.as_sum().is_none()) @@ -232,6 +232,7 @@ impl PartialValue { impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { + self.assert_invariants(); // println!("join {self:?}\n{:?}", &other); match (&*self, other) { (Self::Top, _) => false, @@ -279,6 +280,7 @@ impl Lattice for PartialValue { } fn meet_mut(&mut self, other: Self) -> bool { + self.assert_invariants(); match (&*self, other) { (Self::Bottom, _) => false, (_, other @ Self::Bottom) => { From ce53b1cc4832b1f13f9f171efc87d5bba5bdbd9f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 13 Sep 2024 11:16:40 +0100 Subject: [PATCH 059/281] test unpacking constant tuple --- hugr-passes/src/dataflow/test.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index d8d8698af..e4b3d5c24 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -31,13 +31,11 @@ fn test_make_tuple() { } #[test] -fn test_unpack_tuple() { +fn test_unpack_tuple_const() { let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); - let v1 = builder.add_load_value(Value::false_val()); - let v2 = builder.add_load_value(Value::true_val()); - let v3 = builder.make_tuple([v1, v2]).unwrap(); + let v = builder.add_load_value(Value::tuple([Value::false_val(), Value::true_val()])); let [o1, o2] = builder - .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T, BOOL_T]), [v3]) + .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T, BOOL_T]), [v]) .unwrap() .outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); From 13f29a909e58d4b1cf64fb49df786b6983ede9f9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 12 Sep 2024 21:25:40 +0100 Subject: [PATCH 060/281] try_into_value returns new enum ValueOrSum; TryFrom replaces FromSum --- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/dataflow/machine.rs | 28 ++++++++- hugr-passes/src/dataflow/partial_value.rs | 51 +++++++++++++++- hugr-passes/src/dataflow/total_context.rs | 73 ++--------------------- 4 files changed, 81 insertions(+), 73 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 827489144..52ecca9e1 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -5,7 +5,7 @@ mod machine; pub use machine::Machine; mod partial_value; -pub use partial_value::{AbstractValue, PartialValue}; +pub use partial_value::{AbstractValue, PartialValue, ValueOrSum}; mod value_row; pub use value_row::ValueRow; diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 6fe79208b..d4aa97e78 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -1,8 +1,10 @@ use std::collections::HashMap; -use hugr_core::{ops::Value, HugrView, Node, PortIndex, Wire}; +use hugr_core::{ops::Value, types::ConstTypeError, HugrView, Node, PortIndex, Wire}; -use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; +use super::{ + datalog::AscentProgram, partial_value::ValueOrSum, AbstractValue, DFContext, PartialValue, +}; pub struct Machine>( AscentProgram, @@ -85,7 +87,27 @@ where .out_value_types(w.node()) .find(|(p, _)| *p == w.source()) .unwrap(); - pv.try_into_value(&typ).ok() + let v: ValueOrSum = pv.try_into_value(&typ).ok()?; + v.try_into().ok() + } +} + +impl TryFrom> for Value +where + Value: From, +{ + type Error = ConstTypeError; + fn try_from(value: ValueOrSum) -> Result { + match value { + ValueOrSum::Value(v) => Ok(v.into()), + ValueOrSum::Sum { tag, items, st } => { + let items = items + .into_iter() + .map(Value::try_from) + .collect::, _>>()?; + Value::sum(tag, items, st.clone()) + } + } } } diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 3e1b654ed..e6b53b935 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,6 +1,7 @@ use ascent::lattice::BoundedLattice; use ascent::Lattice; -use itertools::zip_eq; +use hugr_core::types::{SumType, Type, TypeEnum, TypeRow}; +use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; @@ -13,6 +14,16 @@ pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; } +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ValueOrSum { + Value(V), + Sum { + tag: usize, + items: Vec, + st: SumType, + }, +} + // TODO ALAN inline into PartialValue? Has to be public as it's in a pub enum #[derive(PartialEq, Clone, Eq)] pub struct PartialSum(pub HashMap>>); @@ -92,6 +103,36 @@ impl PartialSum { pub fn supports_tag(&self, tag: usize) -> bool { self.0.contains_key(&tag) } + + pub fn try_into_value(self, typ: &Type) -> Result, Self> { + let Ok((k, v)) = self.0.iter().exactly_one() else { + Err(self)? + }; + + let TypeEnum::Sum(st) = typ.as_type_enum() else { + Err(self)? + }; + let Some(r) = st.get_variant(*k) else { + Err(self)? + }; + let Ok(r) = TypeRow::try_from(r.clone()) else { + Err(self)? + }; + if v.len() != r.len() { + return Err(self); + } + match zip_eq(v, r.into_iter()) + .map(|(v, t)| v.clone().try_into_value(t)) + .collect::, _>>() + { + Ok(vs) => Ok(ValueOrSum::Sum { + tag: *k, + items: vs, + st: st.clone(), + }), + Err(_) => Err(self), + } + } } impl PartialSum { @@ -228,6 +269,14 @@ impl PartialValue { PartialValue::Top => true, } } + + pub fn try_into_value(self, typ: &Type) -> Result, Self> { + match self { + Self::Value(v) => Ok(ValueOrSum::Value(v.clone())), + Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), + x => Err(x), + } + } } impl Lattice for PartialValue { diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index 89067105b..26acc31ed 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -1,29 +1,16 @@ use std::hash::Hash; -use hugr_core::ops::Value; -use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; -use itertools::{zip_eq, Itertools}; -use super::partial_value::{AbstractValue, PartialSum, PartialValue}; +use super::partial_value::{AbstractValue, PartialValue, ValueOrSum}; use super::DFContext; use super::ValueRow; -pub trait FromSum: Sized { - type Err: std::error::Error; - fn try_new_sum( - tag: usize, - items: impl IntoIterator, - st: &SumType, - ) -> Result; - fn debug_check_is_type(&self, _ty: &Type) {} -} - /// A simpler interface like [DFContext] but where the context only cares about /// values that are completely known (in the lattice `V`) /// rather than e.g. Sums potentially of two variants each of known values. pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { - type InterpretableVal: FromSum + From; + type InterpretableVal: TryFrom>; fn interpret_leaf_op( &self, node: Node, @@ -31,59 +18,6 @@ pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { ) -> Vec<(OutgoingPort, V)>; } -impl FromSum for Value { - type Err = ConstTypeError; - fn try_new_sum( - tag: usize, - items: impl IntoIterator, - st: &hugr_core::types::SumType, - ) -> Result { - Self::sum(tag, items, st.clone()) - } -} - -// These are here because they rely on FromSum, that they are `impl PartialSum/Value` -// is merely a nice syntax. -impl PartialValue { - pub fn try_into_value>(self, typ: &Type) -> Result { - let r: V2 = match self { - Self::Value(v) => Ok(v.clone().into()), - Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), - x => Err(x), - }?; - r.debug_check_is_type(typ); - Ok(r) - } -} - -impl PartialSum { - pub fn try_into_value>(self, typ: &Type) -> Result { - let Ok((k, v)) = self.0.iter().exactly_one() else { - Err(self)? - }; - - let TypeEnum::Sum(st) = typ.as_type_enum() else { - Err(self)? - }; - let Some(r) = st.get_variant(*k) else { - Err(self)? - }; - let Ok(r): Result = r.clone().try_into() else { - Err(self)? - }; - if v.len() != r.len() { - return Err(self); - } - match zip_eq(v, r.iter()) - .map(|(v, t)| v.clone().try_into_value(t)) - .collect::, _>>() - { - Ok(vs) => V2::try_new_sum(*k, vs, st).map_err(|_| self), - Err(_) => Err(self), - } - } -} - impl> DFContext for T { fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option> { let op = self.get_optype(node); @@ -96,7 +30,10 @@ impl> DFContext for T { .filter_map(|((i, ty), pv)| { pv.clone() .try_into_value(ty) + // Discard PVs which don't produce ValueOrSum, i.e. Bottom/Top :-) .ok() + // And discard any ValueOrSum that don't produce V - this is a bit silent :-( + .and_then(|v_s| T::InterpretableVal::try_from(v_s).ok()) .map(|v| (IncomingPort::from(i), v)) }) .collect::>(); From 9f1a5cda508988b5883816b62b3b3ad25e6f1c37 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 11:43:26 +0100 Subject: [PATCH 061/281] Hide ValueRow (and move into datalog.rs) --- hugr-passes/src/dataflow.rs | 10 +- hugr-passes/src/dataflow/datalog.rs | 122 +++++++++++++++++++++- hugr-passes/src/dataflow/total_context.rs | 9 +- hugr-passes/src/dataflow/value_row.rs | 117 --------------------- 4 files changed, 129 insertions(+), 129 deletions(-) delete mode 100644 hugr-passes/src/dataflow/value_row.rs diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 52ecca9e1..2e2e95936 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -1,15 +1,13 @@ //! Dataflow analysis of Hugrs. mod datalog; + mod machine; pub use machine::Machine; mod partial_value; pub use partial_value::{AbstractValue, PartialValue, ValueOrSum}; -mod value_row; -pub use value_row::ValueRow; - mod total_context; pub use total_context::TotalContext; @@ -17,7 +15,11 @@ use hugr_core::{Hugr, Node}; use std::hash::Hash; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { - fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option>; + fn interpret_leaf_op( + &self, + node: Node, + ins: &[PartialValue], + ) -> Option>>; } #[cfg(test)] diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 06f49150f..c4c798324 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -7,16 +7,20 @@ clippy::collapsible_if )] -use ascent::lattice::BoundedLattice; -use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::types::Signature; +use ascent::lattice::{BoundedLattice, Lattice}; +use itertools::zip_eq; +use std::cmp::Ordering; use std::hash::Hash; +use std::ops::{Index, IndexMut}; +use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use hugr_core::ops::OpType; +use hugr_core::types::Signature; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; -use super::value_row::ValueRow; -use super::{AbstractValue, DFContext}; +use super::partial_value::{AbstractValue, PartialValue}; +use super::DFContext; + type PV = super::partial_value::PartialValue; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -183,3 +187,111 @@ fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator impl Iterator + '_ { h.out_value_types(n).map(|x| x.0) } + +// Wrap a (known-length) row of values into a lattice. Perhaps could be part of partial_value.rs? + +#[derive(PartialEq, Clone, Eq, Hash)] +struct ValueRow(Vec>); + +impl ValueRow { + pub fn new(len: usize) -> Self { + Self(vec![PartialValue::bottom(); len]) + } + + pub fn single_known(len: usize, idx: usize, v: PartialValue) -> Self { + assert!(idx < len); + let mut r = Self::new(len); + r.0[idx] = v; + r + } + + pub fn iter(&self) -> impl Iterator> { + self.0.iter() + } + + pub fn unpack_first( + &self, + variant: usize, + len: usize, + ) -> Option> + '_> { + self[0] + .variant_values(variant, len) + .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) + } + + // fn initialised(&self) -> bool { + // self.0.iter().all(|x| x != &PV::top()) + // } +} + +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl PartialOrd for ValueRow { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} + +impl Lattice for ValueRow { + fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn join_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.join_mut(v2); + } + changed + } + + fn meet_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.meet_mut(v2); + } + changed + } +} + +impl IntoIterator for ValueRow { + type Item = PartialValue; + + type IntoIter = > as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl Index for ValueRow +where + Vec>: Index, +{ + type Output = > as Index>::Output; + + fn index(&self, index: Idx) -> &Self::Output { + self.0.index(index) + } +} + +impl IndexMut for ValueRow +where + Vec>: IndexMut, +{ + fn index_mut(&mut self, index: Idx) -> &mut Self::Output { + self.0.index_mut(index) + } +} diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index 26acc31ed..dc3c7a69a 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -4,7 +4,6 @@ use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, use super::partial_value::{AbstractValue, PartialValue, ValueOrSum}; use super::DFContext; -use super::ValueRow; /// A simpler interface like [DFContext] but where the context only cares about /// values that are completely known (in the lattice `V`) @@ -19,7 +18,11 @@ pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { } impl> DFContext for T { - fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option> { + fn interpret_leaf_op( + &self, + node: Node, + ins: &[PartialValue], + ) -> Option>> { let op = self.get_optype(node); let sig = op.dataflow_signature()?; let known_ins = sig @@ -39,7 +42,7 @@ impl> DFContext for T { .collect::>(); let known_outs = self.interpret_leaf_op(node, &known_ins); (!known_outs.is_empty()).then(|| { - let mut res = ValueRow::new(sig.output_count()); + let mut res = vec![PartialValue::Bottom; sig.output_count()]; for (p, v) in known_outs { res[p.index()] = v.into(); } diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs deleted file mode 100644 index 9f7b8bef7..000000000 --- a/hugr-passes/src/dataflow/value_row.rs +++ /dev/null @@ -1,117 +0,0 @@ -// Really this is part of partial_value.rs - -use std::{ - cmp::Ordering, - ops::{Index, IndexMut}, -}; - -use ascent::lattice::{BoundedLattice, Lattice}; -use itertools::zip_eq; - -use super::partial_value::{AbstractValue, PartialValue}; - -#[derive(PartialEq, Clone, Eq, Hash)] -pub struct ValueRow(Vec>); - -impl ValueRow { - pub fn new(len: usize) -> Self { - Self(vec![PartialValue::bottom(); len]) - } - - pub fn single_known(len: usize, idx: usize, v: PartialValue) -> Self { - assert!(idx < len); - let mut r = Self::new(len); - r.0[idx] = v; - r - } - - pub fn iter(&self) -> impl Iterator> { - self.0.iter() - } - - pub fn unpack_first( - &self, - variant: usize, - len: usize, - ) -> Option> + '_> { - self[0] - .variant_values(variant, len) - .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) - } - - // fn initialised(&self) -> bool { - // self.0.iter().all(|x| x != &PV::top()) - // } -} - -impl FromIterator> for ValueRow { - fn from_iter>>(iter: T) -> Self { - Self(iter.into_iter().collect()) - } -} - -impl PartialOrd for ValueRow { - fn partial_cmp(&self, other: &Self) -> Option { - self.0.partial_cmp(&other.0) - } -} - -impl Lattice for ValueRow { - fn meet(mut self, other: Self) -> Self { - self.meet_mut(other); - self - } - - fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self - } - - fn join_mut(&mut self, other: Self) -> bool { - assert_eq!(self.0.len(), other.0.len()); - let mut changed = false; - for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { - changed |= v1.join_mut(v2); - } - changed - } - - fn meet_mut(&mut self, other: Self) -> bool { - assert_eq!(self.0.len(), other.0.len()); - let mut changed = false; - for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { - changed |= v1.meet_mut(v2); - } - changed - } -} - -impl IntoIterator for ValueRow { - type Item = PartialValue; - - type IntoIter = > as IntoIterator>::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -impl Index for ValueRow -where - Vec>: Index, -{ - type Output = > as Index>::Output; - - fn index(&self, index: Idx) -> &Self::Output { - self.0.index(index) - } -} - -impl IndexMut for ValueRow -where - Vec>: IndexMut, -{ - fn index_mut(&mut self, index: Idx) -> &mut Self::Output { - self.0.index_mut(index) - } -} From 15e642e7d03762d424536def8d6c0abaeaa70380 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 12:30:40 +0100 Subject: [PATCH 062/281] Remove PartialSum::unit() --- hugr-passes/src/dataflow/partial_value.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index e6b53b935..26de29c3f 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -29,9 +29,6 @@ pub enum ValueOrSum { pub struct PartialSum(pub HashMap>>); impl PartialSum { - pub fn unit() -> Self { - Self::variant(0, []) - } pub fn variant(tag: usize, values: impl IntoIterator>) -> Self { Self(HashMap::from([(tag, Vec::from_iter(values))])) } @@ -481,7 +478,7 @@ mod test { }) .boxed() } - Self::Unit => Just(PartialSum::unit().into()).boxed(), + Self::Unit => Just(PartialValue::unit()).boxed(), } } } From 514af1307e6b84e6c166dd8dd58b958f09fda257 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 13:30:17 +0100 Subject: [PATCH 063/281] PartialValue is private struct containing PVEnum (with ::Sum not ::PartialSum) --- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/dataflow/partial_value.rs | 153 ++++++++++++---------- hugr-passes/src/dataflow/total_context.rs | 3 +- 3 files changed, 84 insertions(+), 74 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 2e2e95936..15c08be04 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -6,7 +6,7 @@ mod machine; pub use machine::Machine; mod partial_value; -pub use partial_value::{AbstractValue, PartialValue, ValueOrSum}; +pub use partial_value::{AbstractValue, PVEnum, PartialValue, ValueOrSum}; mod total_context; pub use total_context::TotalContext; diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 26de29c3f..87400759d 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -189,13 +189,23 @@ impl Hash for PartialSum { } } -/// We really must prevent people from constructing PartialValue::Value of -/// any `value` where `value.as_sum().is_some()`` #[derive(PartialEq, Clone, Eq, Hash, Debug)] -pub enum PartialValue { +pub struct PartialValue(PVEnum); + +impl PartialValue { + /// Allows to read the enum, which guarantees that we never return [PVEnum::Value] + /// for a value whose [AbstractValue::as_sum] is `Some` - any such value will be + /// in the form of a [PVEnum::Sum] instead. + pub fn as_enum(&self) -> &PVEnum { + &self.0 + } +} + +#[derive(PartialEq, Clone, Eq, Hash, Debug)] +pub enum PVEnum { Bottom, Value(V), - PartialSum(PartialSum), + Sum(PartialSum), Top, } @@ -203,25 +213,25 @@ impl From for PartialValue { fn from(v: V) -> Self { v.as_sum() .map(|(tag, values)| Self::variant(tag, values.map(Self::from))) - .unwrap_or(Self::Value(v)) + .unwrap_or(Self(PVEnum::Value(v))) } } impl From> for PartialValue { fn from(v: PartialSum) -> Self { - Self::PartialSum(v) + Self(PVEnum::Sum(v)) } } impl PartialValue { - // const BOTTOM: Self = Self::Bottom; + // const BOTTOM: Self = PVEnum::Bottom; // const BOTTOM_REF: &'static Self = &Self::BOTTOM; fn assert_invariants(&self) { - match self { - Self::PartialSum(ps) => { + match &self.0 { + PVEnum::Sum(ps) => { ps.assert_invariants(); } - Self::Value(v) => { + PVEnum::Value(v) => { assert!(v.as_sum().is_none()) } _ => {} @@ -242,36 +252,36 @@ impl PartialValue { } pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { - let vals = match self { - PartialValue::Bottom => return None, - PartialValue::Value(v) => { + let vals = match &self.0 { + PVEnum::Bottom => return None, + PVEnum::Value(v) => { assert!(v.as_sum().is_none()); return None; } - PartialValue::PartialSum(ps) => ps.variant_values(tag, len)?, - PartialValue::Top => vec![PartialValue::Top; len], + PVEnum::Sum(ps) => ps.variant_values(tag, len)?, + PVEnum::Top => vec![PartialValue(PVEnum::Top); len], }; assert_eq!(vals.len(), len); Some(vals) } pub fn supports_tag(&self, tag: usize) -> bool { - match self { - PartialValue::Bottom => false, - PartialValue::Value(v) => { + match &self.0 { + PVEnum::Bottom => false, + PVEnum::Value(v) => { assert!(v.as_sum().is_none()); false } - PartialValue::PartialSum(ps) => ps.supports_tag(tag), - PartialValue::Top => true, + PVEnum::Sum(ps) => ps.supports_tag(tag), + PVEnum::Top => true, } } pub fn try_into_value(self, typ: &Type) -> Result, Self> { - match self { - Self::Value(v) => Ok(ValueOrSum::Value(v.clone())), - Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), - x => Err(x), + match self.0 { + PVEnum::Value(v) => return Ok(ValueOrSum::Value(v.clone())), + PVEnum::Sum(ps) => ps.try_into_value(typ).map_err(Self::from), + _ => Err(self), } } } @@ -280,41 +290,40 @@ impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); // println!("join {self:?}\n{:?}", &other); - match (&*self, other) { - (Self::Top, _) => false, - (_, other @ Self::Top) => { - *self = other; + match (&self.0, other.0) { + (PVEnum::Top, _) => false, + (_, other @ PVEnum::Top) => { + self.0 = other; true } - (_, Self::Bottom) => false, - (Self::Bottom, other) => { - *self = other; + (_, PVEnum::Bottom) => false, + (PVEnum::Bottom, other) => { + self.0 = other; true } - (Self::Value(h1), Self::Value(h2)) => { + (PVEnum::Value(h1), PVEnum::Value(h2)) => { if h1 == &h2 { false } else { - *self = Self::Top; + self.0 = PVEnum::Top; true } } - (Self::PartialSum(_), Self::PartialSum(ps2)) => { - let Self::PartialSum(ps1) = self else { + (PVEnum::Sum(_), PVEnum::Sum(ps2)) => { + let Self(PVEnum::Sum(ps1)) = self else { unreachable!() }; match ps1.try_join_mut(ps2) { Ok(ch) => ch, Err(_) => { - *self = Self::Top; + self.0 = PVEnum::Top; true } } } - (Self::Value(ref v), Self::PartialSum(_)) - | (Self::PartialSum(_), Self::Value(ref v)) => { + (PVEnum::Value(ref v), PVEnum::Sum(_)) | (PVEnum::Sum(_), PVEnum::Value(ref v)) => { assert!(v.as_sum().is_none()); - *self = Self::Top; + self.0 = PVEnum::Top; true } } @@ -327,41 +336,41 @@ impl Lattice for PartialValue { fn meet_mut(&mut self, other: Self) -> bool { self.assert_invariants(); - match (&*self, other) { - (Self::Bottom, _) => false, - (_, other @ Self::Bottom) => { - *self = other; + match (&self.0, other.0) { + (PVEnum::Bottom, _) => false, + (_, other @ PVEnum::Bottom) => { + self.0 = other; true } - (_, Self::Top) => false, - (Self::Top, other) => { - *self = other; + (_, PVEnum::Top) => false, + (PVEnum::Top, other) => { + self.0 = other; true } - (Self::Value(h1), Self::Value(h2)) => { + (PVEnum::Value(h1), PVEnum::Value(h2)) => { if h1 == &h2 { false } else { - *self = Self::Bottom; + self.0 = PVEnum::Bottom; true } } - (Self::PartialSum(_), Self::PartialSum(ps2)) => { - let Self::PartialSum(ps1) = self else { - unreachable!() + (PVEnum::Sum(_), PVEnum::Sum(ps2)) => { + let ps1 = match &mut self.0 { + PVEnum::Sum(ps1) => ps1, + _ => unreachable!(), }; match ps1.try_meet_mut(ps2) { Ok(ch) => ch, Err(_) => { - *self = Self::Bottom; + self.0 = PVEnum::Bottom; true } } } - (Self::Value(ref v), Self::PartialSum(_)) - | (Self::PartialSum(_), Self::Value(ref v)) => { + (PVEnum::Value(ref v), PVEnum::Sum(_)) | (PVEnum::Sum(_), PVEnum::Value(ref v)) => { assert!(v.as_sum().is_none()); - *self = Self::Bottom; + self.0 = PVEnum::Bottom; true } } @@ -370,26 +379,26 @@ impl Lattice for PartialValue { impl BoundedLattice for PartialValue { fn top() -> Self { - Self::Top + Self(PVEnum::Top) } fn bottom() -> Self { - Self::Bottom + Self(PVEnum::Bottom) } } impl PartialOrd for PartialValue { fn partial_cmp(&self, other: &Self) -> Option { use std::cmp::Ordering; - match (self, other) { - (Self::Bottom, Self::Bottom) => Some(Ordering::Equal), - (Self::Top, Self::Top) => Some(Ordering::Equal), - (Self::Bottom, _) => Some(Ordering::Less), - (_, Self::Bottom) => Some(Ordering::Greater), - (Self::Top, _) => Some(Ordering::Greater), - (_, Self::Top) => Some(Ordering::Less), - (Self::Value(v1), Self::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), - (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), + match (&self.0, &other.0) { + (PVEnum::Bottom, PVEnum::Bottom) => Some(Ordering::Equal), + (PVEnum::Top, PVEnum::Top) => Some(Ordering::Equal), + (PVEnum::Bottom, _) => Some(Ordering::Less), + (_, PVEnum::Bottom) => Some(Ordering::Greater), + (PVEnum::Top, _) => Some(Ordering::Greater), + (_, PVEnum::Top) => Some(Ordering::Less), + (PVEnum::Value(v1), PVEnum::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), + (PVEnum::Sum(ps1), PVEnum::Sum(ps2)) => ps1.partial_cmp(ps2), _ => None, } } @@ -408,7 +417,7 @@ mod test { types::{Type, TypeArg, TypeEnum}, }; - use super::{PartialSum, PartialValue}; + use super::{PVEnum, PartialSum, PartialValue}; use crate::const_fold2::value_handle::{ValueHandle, ValueKey}; impl Arbitrary for ValueHandle { @@ -566,10 +575,10 @@ mod test { } fn type_check(&self, pv: &PartialValue) -> bool { - match (self, pv) { - (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, - (_, PartialValue::Value(v)) => self.get_type() == v.get_type(), - (TestSumType::Branch(_, sop), PartialValue::PartialSum(ps)) => { + match (self, pv.as_enum()) { + (_, PVEnum::Bottom) | (_, PVEnum::Top) => true, + (_, PVEnum::Value(v)) => self.get_type() == v.get_type(), + (TestSumType::Branch(_, sop), PVEnum::Sum(ps)) => { for (k, v) in &ps.0 { if *k >= sop.len() { return false; @@ -584,7 +593,7 @@ mod test { } true } - (Self::Leaf(l), PartialValue::PartialSum(ps)) => l.type_check(ps), + (Self::Leaf(l), PVEnum::Sum(ps)) => l.type_check(ps), } } } diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index dc3c7a69a..cba3f08fe 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -1,5 +1,6 @@ use std::hash::Hash; +use ascent::lattice::BoundedLattice; use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; use super::partial_value::{AbstractValue, PartialValue, ValueOrSum}; @@ -42,7 +43,7 @@ impl> DFContext for T { .collect::>(); let known_outs = self.interpret_leaf_op(node, &known_ins); (!known_outs.is_empty()).then(|| { - let mut res = vec![PartialValue::Bottom; sig.output_count()]; + let mut res = vec![PartialValue::bottom(); sig.output_count()]; for (p, v) in known_outs { res[p.index()] = v.into(); } From 5d86f4669ef6d976b28331ab15c67197bafd138f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 13:38:48 +0100 Subject: [PATCH 064/281] variant => new_variant, unit => new_unit --- hugr-passes/src/dataflow/datalog.rs | 9 +++++---- hugr-passes/src/dataflow/partial_value.rs | 16 ++++++++-------- hugr-passes/src/dataflow/test.rs | 6 ++++-- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index c4c798324..b8f5e7230 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -153,15 +153,16 @@ fn propagate_leaf_op( // Handle basics here. I guess (given the current interface) we could allow // DFContext to handle these but at the least we'd want these impls to be // easily available for reuse. - op if op.cast::().is_some() => { - Some(ValueRow::from_iter([PV::variant(0, ins.iter().cloned())])) - } + op if op.cast::().is_some() => Some(ValueRow::from_iter([PV::new_variant( + 0, + ins.iter().cloned(), + )])), op if op.cast::().is_some() => { let [tup] = ins.iter().collect::>().try_into().unwrap(); tup.variant_values(0, value_outputs(c.as_ref(), n).count()) .map(ValueRow::from_iter) } - OpType::Tag(t) => Some(ValueRow::from_iter([PV::variant( + OpType::Tag(t) => Some(ValueRow::from_iter([PV::new_variant( t.tag, ins.iter().cloned(), )])), diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 87400759d..2d3a83742 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -29,7 +29,7 @@ pub enum ValueOrSum { pub struct PartialSum(pub HashMap>>); impl PartialSum { - pub fn variant(tag: usize, values: impl IntoIterator>) -> Self { + pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { Self(HashMap::from([(tag, Vec::from_iter(values))])) } @@ -212,7 +212,7 @@ pub enum PVEnum { impl From for PartialValue { fn from(v: V) -> Self { v.as_sum() - .map(|(tag, values)| Self::variant(tag, values.map(Self::from))) + .map(|(tag, values)| Self::new_variant(tag, values.map(Self::from))) .unwrap_or(Self(PVEnum::Value(v))) } } @@ -243,12 +243,12 @@ impl PartialValue { self } - pub fn variant(tag: usize, values: impl IntoIterator) -> Self { - PartialSum::variant(tag, values).into() + pub fn new_variant(tag: usize, values: impl IntoIterator) -> Self { + PartialSum::new_variant(tag, values).into() } - pub fn unit() -> Self { - Self::variant(0, []) + pub fn new_unit() -> Self { + Self::new_variant(0, []) } pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { @@ -487,7 +487,7 @@ mod test { }) .boxed() } - Self::Unit => Just(PartialValue::unit()).boxed(), + Self::Unit => Just(PartialValue::new_unit()).boxed(), } } } @@ -677,7 +677,7 @@ mod test { ) }) .collect_vec(); - pvs.prop_map(move |pvs| PartialValue::variant(index, pvs)) + pvs.prop_map(move |pvs| PartialValue::new_variant(index, pvs)) .boxed() } }) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index e4b3d5c24..1d9668abf 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -221,8 +221,10 @@ fn conditional() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - let arg_pv = - PartialValue::variant(1, []).join(PartialValue::variant(2, [PartialValue::variant(0, [])])); + let arg_pv = PartialValue::new_variant(1, []).join(PartialValue::new_variant( + 2, + [PartialValue::new_variant(0, [])], + )); machine.propolutate_out_wires([(arg_w, arg_pv)]); machine.run(HugrValueContext::new(&hugr)); From d8c8140b4b7e45b661579c5b27920e7a7217330f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 14:03:09 +0100 Subject: [PATCH 065/281] Simplify PartialOrd for PartialSum, keys(1,2) support cmp --- hugr-passes/src/dataflow/partial_value.rs | 30 ++++++++++------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 2d3a83742..1987eaf2d 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -152,25 +152,21 @@ impl PartialOrd for PartialSum { keys2[*k] = 1; } - if let Some(ord) = keys1.partial_cmp(&keys2) { - if ord != Ordering::Equal { - return Some(ord); - } - } else { - return None; - } - for (k, lhs) in &self.0 { - let Some(rhs) = other.0.get(k) else { - unreachable!() - }; - match lhs.partial_cmp(rhs) { - Some(Ordering::Equal) => continue, - x => { - return x; + Some(match keys1.cmp(&keys2) { + ord @ Ordering::Greater | ord @ Ordering::Less => ord, + Ordering::Equal => { + for (k, lhs) in &self.0 { + let Some(rhs) = other.0.get(k) else { + unreachable!() + }; + let key_cmp = lhs.partial_cmp(rhs); + if key_cmp != Some(Ordering::Equal) { + return key_cmp; + } } + Ordering::Equal } - } - Some(Ordering::Equal) + }) } } From 13156857fdcceebc552757391daa17ac87e4347b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 14:16:51 +0100 Subject: [PATCH 066/281] PartialSum::variant_values does not take `len` (PartialValue:: still does) --- hugr-passes/src/dataflow/partial_value.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 1987eaf2d..c2eff146e 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -133,10 +133,8 @@ impl PartialSum { } impl PartialSum { - pub fn variant_values(&self, variant: usize, len: usize) -> Option>> { - let row = self.0.get(&variant)?; - assert!(row.len() == len); - Some(row.clone()) + pub fn variant_values(&self, variant: usize) -> Option>> { + self.0.get(&variant).cloned() } } @@ -254,7 +252,7 @@ impl PartialValue { assert!(v.as_sum().is_none()); return None; } - PVEnum::Sum(ps) => ps.variant_values(tag, len)?, + PVEnum::Sum(ps) => ps.variant_values(tag)?, PVEnum::Top => vec![PartialValue(PVEnum::Top); len], }; assert_eq!(vals.len(), len); From 2aaaeb9791526e9f2ec5371ea1f5349dd5f12a22 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 15:22:05 +0100 Subject: [PATCH 067/281] clippy --- hugr-passes/src/dataflow/partial_value.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index c2eff146e..7a31478c4 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -118,7 +118,7 @@ impl PartialSum { if v.len() != r.len() { return Err(self); } - match zip_eq(v, r.into_iter()) + match zip_eq(v, r.iter()) .map(|(v, t)| v.clone().try_into_value(t)) .collect::, _>>() { @@ -273,7 +273,7 @@ impl PartialValue { pub fn try_into_value(self, typ: &Type) -> Result, Self> { match self.0 { - PVEnum::Value(v) => return Ok(ValueOrSum::Value(v.clone())), + PVEnum::Value(v) => Ok(ValueOrSum::Value(v.clone())), PVEnum::Sum(ps) => ps.try_into_value(typ).map_err(Self::from), _ => Err(self), } From 5619761cdec00745c8aba786173078554cb0d762 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 17:41:02 +0100 Subject: [PATCH 068/281] Machine::tail_loop_terminates + case_reachable return Option not panic --- hugr-passes/src/dataflow/machine.rs | 32 +++++++++++++++++------------ hugr-passes/src/dataflow/test.rs | 26 ++++++++++++++--------- 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index d4aa97e78..ad57328d1 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -51,27 +51,33 @@ impl> Machine { self.1.as_ref().unwrap().get(&w).cloned() } - pub fn tail_loop_terminates(&self, hugr: impl HugrView, node: Node) -> TailLoopTermination { - assert!(hugr.get_optype(node).is_tail_loop()); + pub fn tail_loop_terminates( + &self, + hugr: impl HugrView, + node: Node, + ) -> Option { + hugr.get_optype(node).as_tail_loop()?; let [_, out] = hugr.get_io(node).unwrap(); - TailLoopTermination::from_control_value( + Some(TailLoopTermination::from_control_value( self.0 .in_wire_value .iter() .find_map(|(_, n, p, v)| (*n == out && p.index() == 0).then_some(v)) .unwrap(), - ) + )) } - pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> bool { - assert!(hugr.get_optype(case).is_case()); - let cond = hugr.get_parent(case).unwrap(); - assert!(hugr.get_optype(cond).is_conditional()); - self.0 - .case_reachable - .iter() - .find_map(|(_, cond2, case2, i)| (&cond == cond2 && &case == case2).then_some(*i)) - .unwrap() + pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> Option { + hugr.get_optype(case).as_case()?; + let cond = hugr.get_parent(case)?; + hugr.get_optype(cond).as_conditional()?; + Some( + self.0 + .case_reachable + .iter() + .find_map(|(_, cond2, case2, i)| (&cond == cond2 && &case == case2).then_some(*i)) + .unwrap(), + ) } } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 1d9668abf..66b9285a4 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -6,11 +6,14 @@ use crate::{ use ascent::lattice::BoundedLattice; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, - extension::prelude::{UnpackTuple, BOOL_T}, - extension::{ExtensionSet, EMPTY_REG}, + extension::{ + prelude::{UnpackTuple, BOOL_T}, + ExtensionSet, EMPTY_REG, + }, ops::{handle::NodeHandle, OpTrait, Value}, type_row, types::{Signature, SumType, Type, TypeRow}, + HugrView, }; use super::partial_value::PartialValue; @@ -93,7 +96,7 @@ fn test_tail_loop_never_iterates() { let o_r = machine.read_out_wire_value(&hugr, tl_o).unwrap(); assert_eq!(o_r, r_v); assert_eq!( - TailLoopTermination::ExactlyZeroContinues, + Some(TailLoopTermination::ExactlyZeroContinues), machine.tail_loop_terminates(&hugr, tail_loop.node()) ) } @@ -125,9 +128,10 @@ fn test_tail_loop_always_iterates() { let o_r2 = machine.read_out_wire_partial_value(tl_o2).unwrap(); assert_eq!(o_r2, PartialValue::bottom()); assert_eq!( - TailLoopTermination::Bottom, + Some(TailLoopTermination::Bottom), machine.tail_loop_terminates(&hugr, tail_loop.node()) - ) + ); + assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); } #[test] @@ -178,9 +182,10 @@ fn test_tail_loop_iterates_twice() { let _ = machine.read_out_wire_partial_value(o_w2).unwrap(); // assert_eq!(o_r2, Value::true_val()); assert_eq!( - TailLoopTermination::Top, + Some(TailLoopTermination::Top), machine.tail_loop_terminates(&hugr, tail_loop.node()) - ) + ); + assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); } #[test] @@ -232,7 +237,8 @@ fn conditional() { assert_eq!(cond_r1, Value::false_val()); assert!(machine.read_out_wire_value(&hugr, cond_o2).is_none()); - assert!(!machine.case_reachable(&hugr, case1.node())); // arg_pv is variant 1 or 2 only - assert!(machine.case_reachable(&hugr, case2.node())); - assert!(machine.case_reachable(&hugr, case3.node())); + assert_eq!(machine.case_reachable(&hugr, case1.node()), Some(false)); // arg_pv is variant 1 or 2 only + assert_eq!(machine.case_reachable(&hugr, case2.node()), Some(true)); + assert_eq!(machine.case_reachable(&hugr, case3.node()), Some(true)); + assert_eq!(machine.case_reachable(&hugr, cond.node()), None); } From f3c175c12e418bf83e0d44429a284691e8bea3b2 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 17:46:37 +0100 Subject: [PATCH 069/281] Machine::read_out_wire_value fails with ConstTypeError if there was one --- hugr-passes/src/dataflow/machine.rs | 15 ++++++++++----- hugr-passes/src/dataflow/test.rs | 2 +- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index ad57328d1..1b169abd2 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -85,16 +85,21 @@ impl> Machine where Value: From, { - pub fn read_out_wire_value(&self, hugr: impl HugrView, w: Wire) -> Option { + pub fn read_out_wire_value( + &self, + hugr: impl HugrView, + w: Wire, + ) -> Result> { // dbg!(&w); - let pv = self.read_out_wire_partial_value(w)?; - // dbg!(&pv); let (_, typ) = hugr .out_value_types(w.node()) .find(|(p, _)| *p == w.source()) .unwrap(); - let v: ValueOrSum = pv.try_into_value(&typ).ok()?; - v.try_into().ok() + let v = self + .read_out_wire_partial_value(w) + .and_then(|pv| pv.try_into_value(&typ).ok()) + .ok_or(None)?; + Ok(v.try_into().map_err(Some)?) } } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 66b9285a4..844aa47d5 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -235,7 +235,7 @@ fn conditional() { let cond_r1 = machine.read_out_wire_value(&hugr, cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); - assert!(machine.read_out_wire_value(&hugr, cond_o2).is_none()); + assert!(machine.read_out_wire_value(&hugr, cond_o2).is_err()); assert_eq!(machine.case_reachable(&hugr, case1.node()), Some(false)); // arg_pv is variant 1 or 2 only assert_eq!(machine.case_reachable(&hugr, case2.node()), Some(true)); From aad2ef00d0474074f60d3f6cbd89cf828c0a52a0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 17:48:02 +0100 Subject: [PATCH 070/281] PartialValue::try_(join|meet)_mut are pub, don't mutate upon failure --- hugr-passes/src/dataflow/partial_value.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 7a31478c4..b5ccc6ade 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -46,16 +46,17 @@ impl PartialSum { } } - // Err with key if any common rows have different lengths (self may have been mutated) - fn try_join_mut(&mut self, other: Self) -> Result { + // Err with key if any common rows have different lengths (self not mutated) + pub fn try_join_mut(&mut self, other: Self) -> Result { + for (k, v) in &other.0 { + if self.0.get(k).is_some_and(|row| row.len() != v.len()) { + return Err(*k); + } + } let mut changed = false; for (k, v) in other.0 { if let Some(row) = self.0.get_mut(&k) { - if v.len() != row.len() { - // Better to check first and avoid mutation, but fine here - return Err(k); - } for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { changed |= lhs.join_mut(rhs); } @@ -68,7 +69,7 @@ impl PartialSum { } // Error with key if any common rows have different lengths ( => Bottom) - fn try_meet_mut(&mut self, other: Self) -> Result { + pub fn try_meet_mut(&mut self, other: Self) -> Result { let mut changed = false; let mut keys_to_remove = vec![]; for (k, v) in self.0.iter() { From 0a8cc12bdabe332047d4d21eecdc9fc5c4eea496 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 17:48:20 +0100 Subject: [PATCH 071/281] Remove some commented-out code --- hugr-passes/src/dataflow/partial_value.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index b5ccc6ade..fb79a93f6 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -219,8 +219,6 @@ impl From> for PartialValue { } impl PartialValue { - // const BOTTOM: Self = PVEnum::Bottom; - // const BOTTOM_REF: &'static Self = &Self::BOTTOM; fn assert_invariants(&self) { match &self.0 { PVEnum::Sum(ps) => { From bfcd0a675109dab997f3f2f5c750ba6c068c9d9a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 17:39:48 +0100 Subject: [PATCH 072/281] Expose PartialSum --- hugr-passes/src/dataflow.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 15c08be04..d80ab275e 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -6,7 +6,7 @@ mod machine; pub use machine::Machine; mod partial_value; -pub use partial_value::{AbstractValue, PVEnum, PartialValue, ValueOrSum}; +pub use partial_value::{AbstractValue, PVEnum, PartialSum, PartialValue, ValueOrSum}; mod total_context; pub use total_context::TotalContext; From 248fb23eaacd880c1695ef13cc6653c3d29cab08 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 17:40:17 +0100 Subject: [PATCH 073/281] dataflow has docs! (enforced) --- hugr-passes/src/dataflow.rs | 5 ++ hugr-passes/src/dataflow/machine.rs | 44 +++++++++++++++-- hugr-passes/src/dataflow/partial_value.rs | 58 +++++++++++++++++++++-- hugr-passes/src/dataflow/total_context.rs | 8 +++- 4 files changed, 103 insertions(+), 12 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index d80ab275e..6a8f94ebd 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -1,3 +1,4 @@ +#![warn(missing_docs)] //! Dataflow analysis of Hugrs. mod datalog; @@ -14,7 +15,11 @@ pub use total_context::TotalContext; use hugr_core::{Hugr, Node}; use std::hash::Hash; +/// Clients of the dataflow framework (particular analyses, such as constant folding) +/// must implement this trait (including providing an appropriate domain type `V`). pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { + /// Given lattice values for each input, produce lattice values for (what we know of) + /// the outputs. Returning `None` indicates nothing can be deduced. fn interpret_leaf_op( &self, node: Node, diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 1b169abd2..d15e49653 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -6,6 +6,11 @@ use super::{ datalog::AscentProgram, partial_value::ValueOrSum, AbstractValue, DFContext, PartialValue, }; +/// Basic structure for performing an analysis. Usage: +/// 1. Get a new instance via [Self::default()] +/// 2. Zero or more [Self::propolutate_out_wires] with initial values +/// 3. Exactly one [Self::run] to do the analysis +/// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] pub struct Machine>( AscentProgram, Option>>, @@ -18,12 +23,9 @@ impl> Default for Machine { } } -/// Usage: -/// 1. Get a new instance via [Self::default()] -/// 2. Zero or more [Self::propolutate_out_wires] with initial values -/// 3. Exactly one [Self::run] to do the analysis -/// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] impl> Machine { + /// Provide initial values for some wires. + /// (For example, if some properties of the Hugr's inputs are known.) pub fn propolutate_out_wires( &mut self, wires: impl IntoIterator)>, @@ -34,6 +36,13 @@ impl> Machine { .extend(wires.into_iter().map(|(w, v)| (w.node(), w.source(), v))); } + /// Run the analysis (iterate until a lattice fixpoint is reached). + /// The context passed in allows interpretation of leaf operations. + /// + /// # Panics + /// + /// If this Machine has been run already. + /// pub fn run(&mut self, context: C) { assert!(self.1.is_none()); self.0.context.push((context,)); @@ -47,10 +56,16 @@ impl> Machine { ) } + /// Gets the lattice value computed by [Self::run] for the given wire pub fn read_out_wire_partial_value(&self, w: Wire) -> Option> { self.1.as_ref().unwrap().get(&w).cloned() } + /// Tells whether a [TailLoop] node can terminate, i.e. whether + /// `Break` and/or `Continue` tags may be returned by the nested DFG. + /// Returns `None` if the specified `node` is not a [TailLoop]. + /// + /// [TailLoop]: hugr_core::ops::TailLoop pub fn tail_loop_terminates( &self, hugr: impl HugrView, @@ -67,6 +82,13 @@ impl> Machine { )) } + /// Tells whether a [Case] node is reachable, i.e. whether the predicate + /// to its parent [Conditional] may possibly have the tag corresponding to the [Case]. + /// Returns `None` if the specified `case` is not a [Case], or is not within a [Conditional] + /// (e.g. a [Case]-rooted Hugr). + /// + /// [Case]: hugr_core::ops::Case + /// [Conditional]: hugr_core::ops::Conditional pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> Option { hugr.get_optype(case).as_case()?; let cond = hugr.get_parent(case)?; @@ -85,6 +107,18 @@ impl> Machine where Value: From, { + /// Gets the Hugr [Value] computed by [Self::run] for the given wire, if possible. + /// (Only if the analysis determined a single `V`, or a Sum of `V`s with a single + /// possible tag, was present on that wire.) + /// + /// # Errors + /// `None` if the analysis did not result in a single [ValueOrSum] on that wire + /// `Some(e)` if conversion to a [Value] produced a [ConstTypeError] + /// + /// # Panics + /// If a [Type] for the specified wire could not be extracted from the Hugr + /// + /// [Type]: hugr_core::types::Type pub fn read_out_wire_value( &self, hugr: impl HugrView, diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index fb79a93f6..7a2e52f01 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -6,17 +6,31 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; -/// Aka, deconstructible into Sum (TryIntoSum ?) +/// Trait for values which can be deconstructed into Sums (with a single known tag). +/// Required for values used in dataflow analysis. pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { - /// We write this way to optimize query/inspection (is-it-a-sum), + /// Deconstruct a value into a single known tag plus a row of values, if it is a [Sum]. + /// Note that one can just always return `None` but this will mean the analysis + /// is unable to understand untupling, and may give inconsistent results wrt. [Tag] + /// operations, etc. + /// + /// The signature is this way to optimize query/inspection (is-it-a-sum), /// at the cost of requiring more cloning during actual conversion /// (inside the lazy Iterator, or for the error case, as Self remains) + /// + /// [Sum]: TypeEnum::Sum + /// [Tag]: hugr_core::ops::Tag fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; } +/// A struct returned from [PartialValue::try_into_value] and [PartialSum::try_into_value] +/// indicating the value is either a single value or a sum with a single known tag. #[derive(Clone, Debug, PartialEq, Eq)] pub enum ValueOrSum { + /// Single value in the domain `V` Value(V), + /// Sum with a single known Tag + #[allow(missing_docs)] Sum { tag: usize, items: Vec, @@ -24,15 +38,20 @@ pub enum ValueOrSum { }, } -// TODO ALAN inline into PartialValue? Has to be public as it's in a pub enum +/// A representation of a value of [SumType], that may have one or more possible tags, +/// with a [PartialValue] representation of each element-value of each possible tag. #[derive(PartialEq, Clone, Eq)] pub struct PartialSum(pub HashMap>>); impl PartialSum { + /// New instance for a single known tag. + /// (Multi-tag instances can be created via [Self::try_join_mut].) pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { Self(HashMap::from([(tag, Vec::from_iter(values))])) } + /// The number of possible variants we know about. (NOT the number + /// of tags possible for the value's type, whatever [SumType] that might be.) pub fn num_variants(&self) -> usize { self.0.len() } @@ -46,7 +65,10 @@ impl PartialSum { } } - // Err with key if any common rows have different lengths (self not mutated) + /// Joins (towards `Top`) self with another [PartialSum]. If successful, returns + /// whether `self` has changed. + /// + /// Fails (without mutation) with the conflicting tag if any common rows have different lengths. pub fn try_join_mut(&mut self, other: Self) -> Result { for (k, v) in &other.0 { if self.0.get(k).is_some_and(|row| row.len() != v.len()) { @@ -68,7 +90,10 @@ impl PartialSum { Ok(changed) } - // Error with key if any common rows have different lengths ( => Bottom) + /// Mutates self according to lattice meet operation (towards `Bottom`). If successful, + /// returns whether `self` has changed. + /// + /// Fails (without mutation) with the conflicting tag if any common rows have different lengths pub fn try_meet_mut(&mut self, other: Self) -> Result { let mut changed = false; let mut keys_to_remove = vec![]; @@ -98,10 +123,14 @@ impl PartialSum { Ok(changed) } + /// Whether this sum might have the specified tag pub fn supports_tag(&self, tag: usize) -> bool { self.0.contains_key(&tag) } + /// Turns this instance into a [ValueOrSum::Sum] if it has exactly one possible tag, + /// otherwise failing and returning itself back unmodified (also if there is another + /// error, e.g. this instance is not described by `typ`). pub fn try_into_value(self, typ: &Type) -> Result, Self> { let Ok((k, v)) = self.0.iter().exactly_one() else { Err(self)? @@ -134,6 +163,7 @@ impl PartialSum { } impl PartialSum { + /// If this Sum might have the specified `tag`, get the elements inside that tag. pub fn variant_values(&self, variant: usize) -> Option>> { self.0.get(&variant).cloned() } @@ -184,6 +214,9 @@ impl Hash for PartialSum { } } +/// Wraps some underlying representation (knowledge) of values into a lattice +/// for use in dataflow analysis, including that an instance may be a [PartialSum] +/// of values of the underlying representation #[derive(PartialEq, Clone, Eq, Hash, Debug)] pub struct PartialValue(PVEnum); @@ -196,11 +229,16 @@ impl PartialValue { } } +/// The contents of a [PartialValue], i.e. used as a view. #[derive(PartialEq, Clone, Eq, Hash, Debug)] pub enum PVEnum { + /// No possibilities known (so far) Bottom, + /// A single value (of the underlying representation) Value(V), + /// Sum (with perhaps several possible tags) of underlying values Sum(PartialSum), + /// Might be more than one distinct value of the underlying type `V` Top, } @@ -231,19 +269,27 @@ impl PartialValue { } } + /// Computes the lattice-join (i.e. towards `Top`) of this [PartialValue] with another. pub fn join(mut self, other: Self) -> Self { self.join_mut(other); self } + /// New instance of a sum with a single known tag. pub fn new_variant(tag: usize, values: impl IntoIterator) -> Self { PartialSum::new_variant(tag, values).into() } + /// New instance of unit type (i.e. the only possible value, with no contents) pub fn new_unit() -> Self { Self::new_variant(0, []) } + /// If this value might be a Sum with the specified `tag`, get the elements inside that tag. + /// + /// # Panics + /// + /// if the value is believed, for that tag, to have a number of values other than `len` pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match &self.0 { PVEnum::Bottom => return None, @@ -258,6 +304,7 @@ impl PartialValue { Some(vals) } + /// Tells us whether this value might be a Sum with the specified `tag` pub fn supports_tag(&self, tag: usize) -> bool { match &self.0 { PVEnum::Bottom => false, @@ -270,6 +317,7 @@ impl PartialValue { } } + /// Extracts a [ValueOrSum] if there is such a single representation pub fn try_into_value(self, typ: &Type) -> Result, Self> { match self.0 { PVEnum::Value(v) => Ok(ValueOrSum::Value(v.clone())), diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index cba3f08fe..262c250f9 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -7,10 +7,14 @@ use super::partial_value::{AbstractValue, PartialValue, ValueOrSum}; use super::DFContext; /// A simpler interface like [DFContext] but where the context only cares about -/// values that are completely known (in the lattice `V`) -/// rather than e.g. Sums potentially of two variants each of known values. +/// values that are completely known (as `V`s), i.e. not `Bottom`, `Top`, or +/// Sums of potentially multiple variants. pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { + /// The representation of values on which [Self::interpret_leaf_op] operates type InterpretableVal: TryFrom>; + /// Interpret a leaf op. + /// `ins` gives the input ports for which we know (interpretable) values, and will be non-empty. + /// Returns a list of output ports for which we know (abstract) values (may be empty). fn interpret_leaf_op( &self, node: Node, From 5b8654e4673800fb6d0129635ea2480496b69fa6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 18:13:44 +0100 Subject: [PATCH 074/281] clippy --- hugr-passes/src/dataflow/machine.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index d15e49653..834e9913b 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -133,7 +133,7 @@ where .read_out_wire_partial_value(w) .and_then(|pv| pv.try_into_value(&typ).ok()) .ok_or(None)?; - Ok(v.try_into().map_err(Some)?) + v.try_into().map_err(Some) } } From c40e718761e235df438835f05ada796c3c47b46e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 7 Oct 2024 11:50:18 +0100 Subject: [PATCH 075/281] Move PartialValue::join into impl Lattice for --- hugr-passes/src/const_fold2.rs | 1 + hugr-passes/src/const_fold2/context.rs | 7 +++++++ hugr-passes/src/const_fold2/value_handle.rs | 2 +- hugr-passes/src/dataflow/partial_value.rs | 11 +++++------ hugr-passes/src/dataflow/test.rs | 2 +- 5 files changed, 15 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index 93b772d88..58f285d43 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -1,3 +1,4 @@ +#![warn(missing_docs)] //! An (example) use of the [super::dataflow](dataflow-analysis framework) //! to perform constant-folding. diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs index c18f5430b..32fc57765 100644 --- a/hugr-passes/src/const_fold2/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -9,6 +9,10 @@ use super::value_handle::{ValueHandle, ValueKey}; use crate::dataflow::TotalContext; /// A context ([DFContext]) for doing analysis with [ValueHandle]s. +/// Interprets [LoadConstant](OpType::LoadConstant) nodes, +/// and [ExtensionOp](OpType::ExtensionOp) nodes where the extension does +/// (using [Value]s for extension-op inputs). +/// /// Just stores a Hugr (actually any [HugrView]), /// (there is )no state for operation-interpretation. /// @@ -17,6 +21,7 @@ use crate::dataflow::TotalContext; pub struct HugrValueContext(Arc); impl HugrValueContext { + /// Creates a new instance, given ownership of the [HugrView] pub fn new(hugr: H) -> Self { Self(Arc::new(hugr)) } @@ -30,6 +35,8 @@ impl Clone for HugrValueContext { } } +// Any value used in an Ascent program must be hashable. +// However, there should only be one DFContext, so its hash is immaterial. impl Hash for HugrValueContext { fn hash(&self, _state: &mut I) {} } diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index 7b6e26106..bbcd25129 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -63,7 +63,7 @@ impl ValueKey { }) } - pub fn field(self, i: usize) -> Self { + fn field(self, i: usize) -> Self { Self::Field(i, Box::new(self)) } } diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 7a2e52f01..85b7ed395 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -269,12 +269,6 @@ impl PartialValue { } } - /// Computes the lattice-join (i.e. towards `Top`) of this [PartialValue] with another. - pub fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self - } - /// New instance of a sum with a single known tag. pub fn new_variant(tag: usize, values: impl IntoIterator) -> Self { PartialSum::new_variant(tag, values).into() @@ -328,6 +322,11 @@ impl PartialValue { } impl Lattice for PartialValue { + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); // println!("join {self:?}\n{:?}", &other); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 844aa47d5..a7c90236f 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -3,7 +3,7 @@ use crate::{ dataflow::{machine::TailLoopTermination, Machine}, }; -use ascent::lattice::BoundedLattice; +use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ From 8732a637756d511e0a67a1a873e7e444f8293c74 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 7 Oct 2024 12:17:59 +0100 Subject: [PATCH 076/281] Machine::read_out_wire_value => PartialValue::try_into_wire_value Is it worth keeping the ValueOrSum intermediate?? --- hugr-passes/src/dataflow/machine.rs | 36 +------------------ hugr-passes/src/dataflow/partial_value.rs | 35 +++++++++++++++++-- hugr-passes/src/dataflow/test.rs | 42 +++++++++++++++++++---- 3 files changed, 69 insertions(+), 44 deletions(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 834e9913b..533ac3a07 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -10,7 +10,7 @@ use super::{ /// 1. Get a new instance via [Self::default()] /// 2. Zero or more [Self::propolutate_out_wires] with initial values /// 3. Exactly one [Self::run] to do the analysis -/// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] +/// 4. Results then available via [Self::read_out_wire_partial_value] pub struct Machine>( AscentProgram, Option>>, @@ -103,40 +103,6 @@ impl> Machine { } } -impl> Machine -where - Value: From, -{ - /// Gets the Hugr [Value] computed by [Self::run] for the given wire, if possible. - /// (Only if the analysis determined a single `V`, or a Sum of `V`s with a single - /// possible tag, was present on that wire.) - /// - /// # Errors - /// `None` if the analysis did not result in a single [ValueOrSum] on that wire - /// `Some(e)` if conversion to a [Value] produced a [ConstTypeError] - /// - /// # Panics - /// If a [Type] for the specified wire could not be extracted from the Hugr - /// - /// [Type]: hugr_core::types::Type - pub fn read_out_wire_value( - &self, - hugr: impl HugrView, - w: Wire, - ) -> Result> { - // dbg!(&w); - let (_, typ) = hugr - .out_value_types(w.node()) - .find(|(p, _)| *p == w.source()) - .unwrap(); - let v = self - .read_out_wire_partial_value(w) - .and_then(|pv| pv.try_into_value(&typ).ok()) - .ok_or(None)?; - v.try_into().map_err(Some) - } -} - impl TryFrom> for Value where Value: From, diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 85b7ed395..727cde174 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,6 +1,8 @@ use ascent::lattice::BoundedLattice; use ascent::Lattice; -use hugr_core::types::{SumType, Type, TypeEnum, TypeRow}; +use hugr_core::ops::Value; +use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; +use hugr_core::{HugrView, Wire}; use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::collections::HashMap; @@ -311,7 +313,8 @@ impl PartialValue { } } - /// Extracts a [ValueOrSum] if there is such a single representation + /// Extracts a [ValueOrSum] if there is such a single representation, + /// given a [Type] pub fn try_into_value(self, typ: &Type) -> Result, Self> { match self.0 { PVEnum::Value(v) => Ok(ValueOrSum::Value(v.clone())), @@ -321,6 +324,34 @@ impl PartialValue { } } +impl PartialValue +where + Value: From, +{ + /// Extracts a [ValueOrSum] if there is such a single representation, + /// given a HugrView and Wire that determine the type. + /// + /// # Errors + /// `None` if the analysis did not result in a single [ValueOrSum] on that wire + /// `Some(e)` if conversion to a [Value] produced a [ConstTypeError] + /// + /// # Panics + /// + /// If a [Type] for the specified wire could not be extracted from the Hugr + pub fn try_into_wire_value( + self, + hugr: &impl HugrView, + w: Wire, + ) -> Result> { + let (_, typ) = hugr + .out_value_types(w.node()) + .find(|(p, _)| *p == w.source()) + .unwrap(); + let vs = self.try_into_value(&typ).map_err(|_| None)?; + vs.try_into().map_err(Some) + } +} + impl Lattice for PartialValue { fn join(mut self, other: Self) -> Self { self.join_mut(other); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index a7c90236f..fcca86b1b 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -29,7 +29,11 @@ fn test_make_tuple() { let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); - let x = machine.read_out_wire_value(&hugr, v3).unwrap(); + let x = machine + .read_out_wire_partial_value(v3) + .unwrap() + .try_into_wire_value(&hugr, v3) + .unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); } @@ -46,9 +50,17 @@ fn test_unpack_tuple_const() { let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); - let o1_r = machine.read_out_wire_value(&hugr, o1).unwrap(); + let o1_r = machine + .read_out_wire_partial_value(o1) + .unwrap() + .try_into_wire_value(&hugr, o1) + .unwrap(); assert_eq!(o1_r, Value::false_val()); - let o2_r = machine.read_out_wire_value(&hugr, o2).unwrap(); + let o2_r = machine + .read_out_wire_partial_value(o2) + .unwrap() + .try_into_wire_value(&hugr, o2) + .unwrap(); assert_eq!(o2_r, Value::true_val()); } @@ -65,7 +77,11 @@ fn test_unpack_const() { let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); - let o_r = machine.read_out_wire_value(&hugr, o).unwrap(); + let o_r = machine + .read_out_wire_partial_value(o) + .unwrap() + .try_into_wire_value(&hugr, o) + .unwrap(); assert_eq!(o_r, Value::true_val()); } @@ -93,7 +109,11 @@ fn test_tail_loop_never_iterates() { // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); - let o_r = machine.read_out_wire_value(&hugr, tl_o).unwrap(); + let o_r = machine + .read_out_wire_partial_value(tl_o) + .unwrap() + .try_into_wire_value(&hugr, tl_o) + .unwrap(); assert_eq!(o_r, r_v); assert_eq!( Some(TailLoopTermination::ExactlyZeroContinues), @@ -233,9 +253,17 @@ fn conditional() { machine.propolutate_out_wires([(arg_w, arg_pv)]); machine.run(HugrValueContext::new(&hugr)); - let cond_r1 = machine.read_out_wire_value(&hugr, cond_o1).unwrap(); + let cond_r1 = machine + .read_out_wire_partial_value(cond_o1) + .unwrap() + .try_into_wire_value(&hugr, cond_o1) + .unwrap(); assert_eq!(cond_r1, Value::false_val()); - assert!(machine.read_out_wire_value(&hugr, cond_o2).is_err()); + assert!(machine + .read_out_wire_partial_value(cond_o2) + .unwrap() + .try_into_wire_value(&hugr, cond_o2) + .is_err()); assert_eq!(machine.case_reachable(&hugr, case1.node()), Some(false)); // arg_pv is variant 1 or 2 only assert_eq!(machine.case_reachable(&hugr, case2.node()), Some(true)); From 346187d7cd64ae2d84514c177f32727554fe8868 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 7 Oct 2024 12:19:32 +0100 Subject: [PATCH 077/281] read_out_wire_partial_value => read_out_wire --- hugr-passes/src/dataflow/machine.rs | 4 ++-- hugr-passes/src/dataflow/test.rs | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 533ac3a07..f1b685fd8 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -10,7 +10,7 @@ use super::{ /// 1. Get a new instance via [Self::default()] /// 2. Zero or more [Self::propolutate_out_wires] with initial values /// 3. Exactly one [Self::run] to do the analysis -/// 4. Results then available via [Self::read_out_wire_partial_value] +/// 4. Results then available via [Self::read_out_wire] pub struct Machine>( AscentProgram, Option>>, @@ -57,7 +57,7 @@ impl> Machine { } /// Gets the lattice value computed by [Self::run] for the given wire - pub fn read_out_wire_partial_value(&self, w: Wire) -> Option> { + pub fn read_out_wire(&self, w: Wire) -> Option> { self.1.as_ref().unwrap().get(&w).cloned() } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index fcca86b1b..cfc9b8975 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -30,7 +30,7 @@ fn test_make_tuple() { machine.run(HugrValueContext::new(&hugr)); let x = machine - .read_out_wire_partial_value(v3) + .read_out_wire(v3) .unwrap() .try_into_wire_value(&hugr, v3) .unwrap(); @@ -51,13 +51,13 @@ fn test_unpack_tuple_const() { machine.run(HugrValueContext::new(&hugr)); let o1_r = machine - .read_out_wire_partial_value(o1) + .read_out_wire(o1) .unwrap() .try_into_wire_value(&hugr, o1) .unwrap(); assert_eq!(o1_r, Value::false_val()); let o2_r = machine - .read_out_wire_partial_value(o2) + .read_out_wire(o2) .unwrap() .try_into_wire_value(&hugr, o2) .unwrap(); @@ -78,7 +78,7 @@ fn test_unpack_const() { machine.run(HugrValueContext::new(&hugr)); let o_r = machine - .read_out_wire_partial_value(o) + .read_out_wire(o) .unwrap() .try_into_wire_value(&hugr, o) .unwrap(); @@ -110,7 +110,7 @@ fn test_tail_loop_never_iterates() { // dbg!(&machine.out_wire_value); let o_r = machine - .read_out_wire_partial_value(tl_o) + .read_out_wire(tl_o) .unwrap() .try_into_wire_value(&hugr, tl_o) .unwrap(); @@ -143,9 +143,9 @@ fn test_tail_loop_always_iterates() { let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); - let o_r1 = machine.read_out_wire_partial_value(tl_o1).unwrap(); + let o_r1 = machine.read_out_wire(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom()); - let o_r2 = machine.read_out_wire_partial_value(tl_o2).unwrap(); + let o_r2 = machine.read_out_wire(tl_o2).unwrap(); assert_eq!(o_r2, PartialValue::bottom()); assert_eq!( Some(TailLoopTermination::Bottom), @@ -197,9 +197,9 @@ fn test_tail_loop_iterates_twice() { // dbg!(&machine.out_wire_value); // TODO these hould be the propagated values for now they will bt join(true,false) - let _ = machine.read_out_wire_partial_value(o_w1).unwrap(); + let _ = machine.read_out_wire(o_w1).unwrap(); // assert_eq!(o_r1, PartialValue::top()); - let _ = machine.read_out_wire_partial_value(o_w2).unwrap(); + let _ = machine.read_out_wire(o_w2).unwrap(); // assert_eq!(o_r2, Value::true_val()); assert_eq!( Some(TailLoopTermination::Top), @@ -254,13 +254,13 @@ fn conditional() { machine.run(HugrValueContext::new(&hugr)); let cond_r1 = machine - .read_out_wire_partial_value(cond_o1) + .read_out_wire(cond_o1) .unwrap() .try_into_wire_value(&hugr, cond_o1) .unwrap(); assert_eq!(cond_r1, Value::false_val()); assert!(machine - .read_out_wire_partial_value(cond_o2) + .read_out_wire(cond_o2) .unwrap() .try_into_wire_value(&hugr, cond_o2) .is_err()); From a139f9e0e78eccb49c45f7fa5e8b4f82b46f82d6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 7 Oct 2024 12:53:16 +0100 Subject: [PATCH 078/281] Remove ValueOrSum (and add Sum) via complex parametrization of try_into_value --- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/dataflow/machine.rs | 25 +------- hugr-passes/src/dataflow/partial_value.rs | 70 ++++++++++++++--------- hugr-passes/src/dataflow/total_context.rs | 9 +-- 4 files changed, 49 insertions(+), 57 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 6a8f94ebd..f7c7555fa 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -7,7 +7,7 @@ mod machine; pub use machine::Machine; mod partial_value; -pub use partial_value::{AbstractValue, PVEnum, PartialSum, PartialValue, ValueOrSum}; +pub use partial_value::{AbstractValue, PVEnum, PartialSum, PartialValue, Sum}; mod total_context; pub use total_context::TotalContext; diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index f1b685fd8..acd3ac1ca 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -1,10 +1,8 @@ use std::collections::HashMap; -use hugr_core::{ops::Value, types::ConstTypeError, HugrView, Node, PortIndex, Wire}; +use hugr_core::{HugrView, Node, PortIndex, Wire}; -use super::{ - datalog::AscentProgram, partial_value::ValueOrSum, AbstractValue, DFContext, PartialValue, -}; +use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] @@ -103,25 +101,6 @@ impl> Machine { } } -impl TryFrom> for Value -where - Value: From, -{ - type Error = ConstTypeError; - fn try_from(value: ValueOrSum) -> Result { - match value { - ValueOrSum::Value(v) => Ok(v.into()), - ValueOrSum::Sum { tag, items, st } => { - let items = items - .into_iter() - .map(Value::try_from) - .collect::, _>>()?; - Value::sum(tag, items, st.clone()) - } - } - } -} - #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] pub enum TailLoopTermination { Bottom, diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 727cde174..4eef5787d 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -25,19 +25,18 @@ pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; } -/// A struct returned from [PartialValue::try_into_value] and [PartialSum::try_into_value] -/// indicating the value is either a single value or a sum with a single known tag. +/// Represents a sum with a single/known tag, abstracted over the representation of the elements. +/// (Identical to [Sum](hugr_core::ops::constant::Sum) except for the type abstraction.) #[derive(Clone, Debug, PartialEq, Eq)] -pub enum ValueOrSum { - /// Single value in the domain `V` - Value(V), - /// Sum with a single known Tag - #[allow(missing_docs)] - Sum { - tag: usize, - items: Vec, - st: SumType, - }, +pub struct Sum { + /// The tag index of the variant. + pub tag: usize, + /// The value of the variant. + /// + /// Sum variants are always a row of values, hence the Vec. + pub values: Vec, + /// The full type of the Sum, including the other variants. + pub st: SumType, } /// A representation of a value of [SumType], that may have one or more possible tags, @@ -130,10 +129,14 @@ impl PartialSum { self.0.contains_key(&tag) } - /// Turns this instance into a [ValueOrSum::Sum] if it has exactly one possible tag, + /// Turns this instance into a [Sum] if it has exactly one possible tag, /// otherwise failing and returning itself back unmodified (also if there is another /// error, e.g. this instance is not described by `typ`). - pub fn try_into_value(self, typ: &Type) -> Result, Self> { + // ALAN is this too parametric? Should we fix V2 == Value? Is the 'Self' error useful (no?) + pub fn try_into_value + TryFrom>>( + self, + typ: &Type, + ) -> Result, Self> { let Ok((k, v)) = self.0.iter().exactly_one() else { Err(self)? }; @@ -154,9 +157,9 @@ impl PartialSum { .map(|(v, t)| v.clone().try_into_value(t)) .collect::, _>>() { - Ok(vs) => Ok(ValueOrSum::Sum { + Ok(values) => Ok(Sum { tag: *k, - items: vs, + values, st: st.clone(), }), Err(_) => Err(self), @@ -313,26 +316,40 @@ impl PartialValue { } } - /// Extracts a [ValueOrSum] if there is such a single representation, - /// given a [Type] - pub fn try_into_value(self, typ: &Type) -> Result, Self> { + /// Extracts a value (in any representation supporting both leaf values and sums) + // ALAN is this too parametric? Should we fix V2 == Value? Is the error useful (should we have 'Self') or is it a smell? + pub fn try_into_value + TryFrom>>( + self, + typ: &Type, + ) -> Result>>::Error>> { match self.0 { - PVEnum::Value(v) => Ok(ValueOrSum::Value(v.clone())), - PVEnum::Sum(ps) => ps.try_into_value(typ).map_err(Self::from), - _ => Err(self), + PVEnum::Value(v) => Ok(V2::from(v.clone())), + PVEnum::Sum(ps) => { + let v = ps.try_into_value(typ).map_err(|_| None)?; + V2::try_from(v).map_err(Some) + } + _ => Err(None), } } } +impl TryFrom> for Value { + type Error = ConstTypeError; + + fn try_from(value: Sum) -> Result { + Self::sum(value.tag, value.values, value.st) + } +} + impl PartialValue where Value: From, { - /// Extracts a [ValueOrSum] if there is such a single representation, - /// given a HugrView and Wire that determine the type. + /// Turns this instance into a [Value], if it is either a single [value](PVEnum::Value) or + /// a [sum](PVEnum::Sum) with a single known tag, extracting the desired type from a HugrView and Wire. /// /// # Errors - /// `None` if the analysis did not result in a single [ValueOrSum] on that wire + /// `None` if the analysis did not result in a single value on that wire /// `Some(e)` if conversion to a [Value] produced a [ConstTypeError] /// /// # Panics @@ -347,8 +364,7 @@ where .out_value_types(w.node()) .find(|(p, _)| *p == w.source()) .unwrap(); - let vs = self.try_into_value(&typ).map_err(|_| None)?; - vs.try_into().map_err(Some) + self.try_into_value(&typ) } } diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index 262c250f9..2326d78cd 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -3,7 +3,7 @@ use std::hash::Hash; use ascent::lattice::BoundedLattice; use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; -use super::partial_value::{AbstractValue, PartialValue, ValueOrSum}; +use super::partial_value::{AbstractValue, PartialValue, Sum}; use super::DFContext; /// A simpler interface like [DFContext] but where the context only cares about @@ -11,7 +11,7 @@ use super::DFContext; /// Sums of potentially multiple variants. pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { /// The representation of values on which [Self::interpret_leaf_op] operates - type InterpretableVal: TryFrom>; + type InterpretableVal: From + TryFrom>; /// Interpret a leaf op. /// `ins` gives the input ports for which we know (interpretable) values, and will be non-empty. /// Returns a list of output ports for which we know (abstract) values (may be empty). @@ -37,11 +37,8 @@ impl> DFContext for T { .zip(ins.iter()) .filter_map(|((i, ty), pv)| { pv.clone() - .try_into_value(ty) - // Discard PVs which don't produce ValueOrSum, i.e. Bottom/Top :-) + .try_into_value::<>::InterpretableVal>(ty) .ok() - // And discard any ValueOrSum that don't produce V - this is a bit silent :-( - .and_then(|v_s| T::InterpretableVal::try_from(v_s).ok()) .map(|v| (IncomingPort::from(i), v)) }) .collect::>(); From 7f2a91a5fc5bc26143f5e82543c172e10ebea90d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 7 Oct 2024 17:18:35 +0100 Subject: [PATCH 079/281] Datalog works on any AbstractValue; impl'd by PartialValue for a BaseValue --- hugr-passes/src/const_fold2/value_handle.rs | 4 +- hugr-passes/src/dataflow.rs | 18 +---- hugr-passes/src/dataflow/datalog.rs | 90 +++++++++++++-------- hugr-passes/src/dataflow/machine.rs | 19 ++--- hugr-passes/src/dataflow/partial_value.rs | 79 +++++++++--------- hugr-passes/src/dataflow/test.rs | 2 +- hugr-passes/src/dataflow/total_context.rs | 6 +- 7 files changed, 114 insertions(+), 104 deletions(-) diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index bbcd25129..59a08b50a 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -7,7 +7,7 @@ use hugr_core::ops::Value; use hugr_core::types::Type; use hugr_core::Node; -use crate::dataflow::AbstractValue; +use crate::dataflow::BaseValue; #[derive(Clone, Debug)] pub struct HashedConst { @@ -85,7 +85,7 @@ impl ValueHandle { } } -impl AbstractValue for ValueHandle { +impl BaseValue for ValueHandle { fn as_sum(&self) -> Option<(usize, impl Iterator + '_)> { match self.value() { Value::Sum(Sum { tag, values, .. }) => Some(( diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index f7c7555fa..a66edde03 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -2,30 +2,16 @@ //! Dataflow analysis of Hugrs. mod datalog; +pub use datalog::{AbstractValue, DFContext}; mod machine; pub use machine::Machine; mod partial_value; -pub use partial_value::{AbstractValue, PVEnum, PartialSum, PartialValue, Sum}; +pub use partial_value::{BaseValue, PVEnum, PartialSum, PartialValue, Sum}; mod total_context; pub use total_context::TotalContext; -use hugr_core::{Hugr, Node}; -use std::hash::Hash; - -/// Clients of the dataflow framework (particular analyses, such as constant folding) -/// must implement this trait (including providing an appropriate domain type `V`). -pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { - /// Given lattice values for each input, produce lattice values for (what we know of) - /// the outputs. Returning `None` indicates nothing can be deduced. - fn interpret_leaf_op( - &self, - node: Node, - ins: &[PartialValue], - ) -> Option>>; -} - #[cfg(test)] mod test; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index b8f5e7230..fcde4f96b 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -16,12 +16,7 @@ use std::ops::{Index, IndexMut}; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use hugr_core::ops::OpType; use hugr_core::types::Signature; -use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; - -use super::partial_value::{AbstractValue, PartialValue}; -use super::DFContext; - -type PV = super::partial_value::PartialValue; +use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum IO { @@ -29,19 +24,50 @@ pub enum IO { Output, } +/// Clients of the dataflow framework (particular analyses, such as constant folding) +/// must implement this trait (including providing an appropriate domain type `PV`). +pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { + /// Given lattice values for each input, produce lattice values for (what we know of) + /// the outputs. Returning `None` indicates nothing can be deduced. + fn interpret_leaf_op(&self, node: Node, ins: &[PV]) -> Option>; +} + +/// Values which can be the domain for dataflow analysis. Must be able to deconstructed +/// into (and constructed from) Sums as these determine control flow. +pub trait AbstractValue: BoundedLattice + Clone + Eq + Hash + std::fmt::Debug { + /// Create a new instance representing a Sum with a single known tag + /// and (recursive) representations of the elements within that tag. + fn new_variant(tag: usize, values: impl IntoIterator) -> Self; + + /// New instance of unit type (i.e. the only possible value, with no contents) + fn new_unit() -> Self { + Self::new_variant(0, []) + } + + /// Test whether this value *might* be a Sum with the specified tag. + fn supports_tag(&self, tag: usize) -> bool; + + /// If this value might be a Sum with the specified tag, return values + /// describing the elements of the Sum, otherwise `None`. + /// + /// Implementations must hold the invariant that for all `x`, `tag` and `len`: + /// `x.variant_values(tag, len).is_some() == x.supports_tag(tag)` + fn variant_values(&self, tag: usize, len: usize) -> Option>; +} + ascent::ascent! { - pub(super) struct AscentProgram>; + pub(super) struct AscentProgram>; relation context(C); - relation out_wire_value_proto(Node, OutgoingPort, PV); + relation out_wire_value_proto(Node, OutgoingPort, PV); relation node(C, Node); relation in_wire(C, Node, IncomingPort); relation out_wire(C, Node, OutgoingPort); relation parent_of_node(C, Node, Node); relation io_node(C, Node, Node, IO); - lattice out_wire_value(C, Node, OutgoingPort, PV); - lattice node_in_value_row(C, Node, ValueRow); - lattice in_wire_value(C, Node, IncomingPort, PV); + lattice out_wire_value(C, Node, OutgoingPort, PV); + lattice node_in_value_row(C, Node, ValueRow); + lattice in_wire_value(C, Node, IncomingPort, PV); node(c, n) <-- context(c), for n in c.nodes(); @@ -144,11 +170,11 @@ ascent::ascent! { } -fn propagate_leaf_op( - c: &impl DFContext, +fn propagate_leaf_op( + c: &impl DFContext, n: Node, - ins: &[PV], -) -> Option> { + ins: &[PV], +) -> Option> { match c.get_optype(n) { // Handle basics here. I guess (given the current interface) we could allow // DFContext to handle these but at the least we'd want these impls to be @@ -192,21 +218,21 @@ fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator(Vec>); +struct ValueRow(Vec); -impl ValueRow { +impl ValueRow { pub fn new(len: usize) -> Self { - Self(vec![PartialValue::bottom(); len]) + Self(vec![PV::bottom(); len]) } - pub fn single_known(len: usize, idx: usize, v: PartialValue) -> Self { + pub fn single_known(len: usize, idx: usize, v: PV) -> Self { assert!(idx < len); let mut r = Self::new(len); r.0[idx] = v; r } - pub fn iter(&self) -> impl Iterator> { + pub fn iter(&self) -> impl Iterator { self.0.iter() } @@ -214,7 +240,7 @@ impl ValueRow { &self, variant: usize, len: usize, - ) -> Option> + '_> { + ) -> Option + '_> { self[0] .variant_values(variant, len) .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) @@ -225,13 +251,13 @@ impl ValueRow { // } } -impl FromIterator> for ValueRow { - fn from_iter>>(iter: T) -> Self { +impl FromIterator for ValueRow { + fn from_iter>(iter: T) -> Self { Self(iter.into_iter().collect()) } } -impl PartialOrd for ValueRow { +impl PartialOrd for ValueRow { fn partial_cmp(&self, other: &Self) -> Option { self.0.partial_cmp(&other.0) } @@ -267,30 +293,30 @@ impl Lattice for ValueRow { } } -impl IntoIterator for ValueRow { - type Item = PartialValue; +impl IntoIterator for ValueRow { + type Item = PV; - type IntoIter = > as IntoIterator>::IntoIter; + type IntoIter = as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() } } -impl Index for ValueRow +impl Index for ValueRow where - Vec>: Index, + Vec: Index, { - type Output = > as Index>::Output; + type Output = as Index>::Output; fn index(&self, index: Idx) -> &Self::Output { self.0.index(index) } } -impl IndexMut for ValueRow +impl IndexMut for ValueRow where - Vec>: IndexMut, + Vec: IndexMut, { fn index_mut(&mut self, index: Idx) -> &mut Self::Output { self.0.index_mut(index) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index acd3ac1ca..986fafa76 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -2,16 +2,16 @@ use std::collections::HashMap; use hugr_core::{HugrView, Node, PortIndex, Wire}; -use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; +use super::{datalog::AscentProgram, AbstractValue, DFContext}; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] /// 2. Zero or more [Self::propolutate_out_wires] with initial values /// 3. Exactly one [Self::run] to do the analysis /// 4. Results then available via [Self::read_out_wire] -pub struct Machine>( - AscentProgram, - Option>>, +pub struct Machine>( + AscentProgram, + Option>, ); /// derived-Default requires the context to be Defaultable, which is unnecessary @@ -21,13 +21,10 @@ impl> Default for Machine { } } -impl> Machine { +impl> Machine { /// Provide initial values for some wires. /// (For example, if some properties of the Hugr's inputs are known.) - pub fn propolutate_out_wires( - &mut self, - wires: impl IntoIterator)>, - ) { + pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator) { assert!(self.1.is_none()); self.0 .out_wire_value_proto @@ -55,7 +52,7 @@ impl> Machine { } /// Gets the lattice value computed by [Self::run] for the given wire - pub fn read_out_wire(&self, w: Wire) -> Option> { + pub fn read_out_wire(&self, w: Wire) -> Option { self.1.as_ref().unwrap().get(&w).cloned() } @@ -109,7 +106,7 @@ pub enum TailLoopTermination { } impl TailLoopTermination { - pub fn from_control_value(v: &PartialValue) -> Self { + pub fn from_control_value(v: &impl AbstractValue) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); if may_break && !may_continue { Self::ExactlyZeroContinues diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 4eef5787d..e985eceab 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -8,10 +8,12 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; -/// Trait for values which can be deconstructed into Sums (with a single known tag). -/// Required for values used in dataflow analysis. -pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { - /// Deconstruct a value into a single known tag plus a row of values, if it is a [Sum]. +use super::AbstractValue; + +/// Trait for abstract values that may represent sums. +/// Can be wrapped into an [AbstractValue] for analysis via [PartialValue]. +pub trait BaseValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { + /// Deconstruct an abstract value into a single known tag plus a row of values, if it is a [Sum]. /// Note that one can just always return `None` but this will mean the analysis /// is unable to understand untupling, and may give inconsistent results wrt. [Tag] /// operations, etc. @@ -58,7 +60,7 @@ impl PartialSum { } } -impl PartialSum { +impl PartialSum { fn assert_invariants(&self) { assert_ne!(self.num_variants(), 0); for pv in self.0.values().flat_map(|x| x.iter()) { @@ -219,15 +221,15 @@ impl Hash for PartialSum { } } -/// Wraps some underlying representation (knowledge) of values into a lattice -/// for use in dataflow analysis, including that an instance may be a [PartialSum] -/// of values of the underlying representation +/// Wraps some underlying representation of values (that `impl`s [BaseValue]) into +/// a lattice for use in dataflow analysis, including that an instance may be +/// a [PartialSum] of values of the underlying representation #[derive(PartialEq, Clone, Eq, Hash, Debug)] pub struct PartialValue(PVEnum); impl PartialValue { /// Allows to read the enum, which guarantees that we never return [PVEnum::Value] - /// for a value whose [AbstractValue::as_sum] is `Some` - any such value will be + /// for a value whose [BaseValue::as_sum] is `Some` - any such value will be /// in the form of a [PVEnum::Sum] instead. pub fn as_enum(&self) -> &PVEnum { &self.0 @@ -247,7 +249,7 @@ pub enum PVEnum { Top, } -impl From for PartialValue { +impl From for PartialValue { fn from(v: V) -> Self { v.as_sum() .map(|(tag, values)| Self::new_variant(tag, values.map(Self::from))) @@ -261,7 +263,7 @@ impl From> for PartialValue { } } -impl PartialValue { +impl PartialValue { fn assert_invariants(&self) { match &self.0 { PVEnum::Sum(ps) => { @@ -274,22 +276,30 @@ impl PartialValue { } } - /// New instance of a sum with a single known tag. - pub fn new_variant(tag: usize, values: impl IntoIterator) -> Self { - PartialSum::new_variant(tag, values).into() - } - - /// New instance of unit type (i.e. the only possible value, with no contents) - pub fn new_unit() -> Self { - Self::new_variant(0, []) + /// Extracts a value (in any representation supporting both leaf values and sums) + // ALAN is this too parametric? Should we fix V2 == Value? Is the error useful (should we have 'Self') or is it a smell? + pub fn try_into_value + TryFrom>>( + self, + typ: &Type, + ) -> Result>>::Error>> { + match self.0 { + PVEnum::Value(v) => Ok(V2::from(v.clone())), + PVEnum::Sum(ps) => { + let v = ps.try_into_value(typ).map_err(|_| None)?; + V2::try_from(v).map_err(Some) + } + _ => Err(None), + } } +} +impl AbstractValue for PartialValue { /// If this value might be a Sum with the specified `tag`, get the elements inside that tag. /// /// # Panics /// /// if the value is believed, for that tag, to have a number of values other than `len` - pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { + fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match &self.0 { PVEnum::Bottom => return None, PVEnum::Value(v) => { @@ -304,7 +314,7 @@ impl PartialValue { } /// Tells us whether this value might be a Sum with the specified `tag` - pub fn supports_tag(&self, tag: usize) -> bool { + fn supports_tag(&self, tag: usize) -> bool { match &self.0 { PVEnum::Bottom => false, PVEnum::Value(v) => { @@ -316,20 +326,8 @@ impl PartialValue { } } - /// Extracts a value (in any representation supporting both leaf values and sums) - // ALAN is this too parametric? Should we fix V2 == Value? Is the error useful (should we have 'Self') or is it a smell? - pub fn try_into_value + TryFrom>>( - self, - typ: &Type, - ) -> Result>>::Error>> { - match self.0 { - PVEnum::Value(v) => Ok(V2::from(v.clone())), - PVEnum::Sum(ps) => { - let v = ps.try_into_value(typ).map_err(|_| None)?; - V2::try_from(v).map_err(Some) - } - _ => Err(None), - } + fn new_variant(tag: usize, values: impl IntoIterator) -> Self { + PartialSum::new_variant(tag, values).into() } } @@ -341,7 +339,7 @@ impl TryFrom> for Value { } } -impl PartialValue +impl PartialValue where Value: From, { @@ -368,7 +366,7 @@ where } } -impl Lattice for PartialValue { +impl Lattice for PartialValue { fn join(mut self, other: Self) -> Self { self.join_mut(other); self @@ -464,7 +462,7 @@ impl Lattice for PartialValue { } } -impl BoundedLattice for PartialValue { +impl BoundedLattice for PartialValue { fn top() -> Self { Self(PVEnum::Top) } @@ -505,7 +503,10 @@ mod test { }; use super::{PVEnum, PartialSum, PartialValue}; - use crate::const_fold2::value_handle::{ValueHandle, ValueKey}; + use crate::{ + const_fold2::value_handle::{ValueHandle, ValueKey}, + dataflow::AbstractValue, + }; impl Arbitrary for ValueHandle { type Parameters = (); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index cfc9b8975..127dcc373 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,6 +1,6 @@ use crate::{ const_fold2::HugrValueContext, - dataflow::{machine::TailLoopTermination, Machine}, + dataflow::{machine::TailLoopTermination, AbstractValue, Machine}, }; use ascent::{lattice::BoundedLattice, Lattice}; diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index 2326d78cd..d512912d0 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -3,8 +3,8 @@ use std::hash::Hash; use ascent::lattice::BoundedLattice; use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; -use super::partial_value::{AbstractValue, PartialValue, Sum}; -use super::DFContext; +use super::partial_value::{PartialValue, Sum}; +use super::{BaseValue, DFContext}; /// A simpler interface like [DFContext] but where the context only cares about /// values that are completely known (as `V`s), i.e. not `Bottom`, `Top`, or @@ -22,7 +22,7 @@ pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { ) -> Vec<(OutgoingPort, V)>; } -impl> DFContext for T { +impl> DFContext> for T { fn interpret_leaf_op( &self, node: Node, From 1680829179d4223bc728e43b63b60759852bfdd6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 16:57:28 +0100 Subject: [PATCH 080/281] PartialValue proptests: rm TestSumLeafType, replace ValueHandle with TestValue --- hugr-passes/Cargo.toml | 1 + hugr-passes/src/dataflow/partial_value.rs | 386 ++++++++-------------- 2 files changed, 134 insertions(+), 253 deletions(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 4234f7f95..8e68e9ad7 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -31,3 +31,4 @@ extension_inference = ["hugr-core/extension_inference"] rstest = { workspace = true } proptest = { workspace = true } proptest-derive = { workspace = true } +proptest-recurse = { version = "0.5.0" } \ No newline at end of file diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index e985eceab..272e43f2a 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -10,21 +10,26 @@ use std::hash::{Hash, Hasher}; use super::AbstractValue; -/// Trait for abstract values that may represent sums. -/// Can be wrapped into an [AbstractValue] for analysis via [PartialValue]. +/// Trait for abstract values that can be wrapped by [PartialValue] for dataflow analysis. +/// (Allows the values to represent sums, but does not require this). pub trait BaseValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { - /// Deconstruct an abstract value into a single known tag plus a row of values, if it is a [Sum]. - /// Note that one can just always return `None` but this will mean the analysis - /// is unable to understand untupling, and may give inconsistent results wrt. [Tag] - /// operations, etc. + /// If the abstract value represents a [Sum] with a single known tag, deconstruct it + /// into that tag plus the elements. The default just returns `None` which is + /// appropriate if the abstract value never does (in which case [interpret_leaf_op] + /// must produce a [PartialValue::new_variant] for any operation producing + /// a sum). /// /// The signature is this way to optimize query/inspection (is-it-a-sum), /// at the cost of requiring more cloning during actual conversion /// (inside the lazy Iterator, or for the error case, as Self remains) /// + /// [interpret_leaf_op]: super::DFContext::interpret_leaf_op /// [Sum]: TypeEnum::Sum /// [Tag]: hugr_core::ops::Tag - fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; + fn as_sum(&self) -> Option<(usize, impl Iterator + '_)> { + let res: Option<(usize, as IntoIterator>::IntoIter)> = None; + res + } } /// Represents a sum with a single/known tag, abstracted over the representation of the elements. @@ -494,179 +499,51 @@ mod test { use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; - use itertools::{zip_eq, Either, Itertools as _}; + use itertools::{zip_eq, Itertools as _}; + use prop::sample::subsequence; use proptest::prelude::*; - use hugr_core::{ - std_extensions::arithmetic::int_types::{self, ConstInt, INT_TYPES, LOG_WIDTH_BOUND}, - types::{Type, TypeArg, TypeEnum}, - }; - - use super::{PVEnum, PartialSum, PartialValue}; - use crate::{ - const_fold2::value_handle::{ValueHandle, ValueKey}, - dataflow::AbstractValue, - }; - - impl Arbitrary for ValueHandle { - type Parameters = (); - type Strategy = BoxedStrategy; - fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { - // prop_oneof![ - - // ] - todo!() - } - } + use proptest_recurse::{StrategyExt, StrategySet}; - #[derive(Debug, PartialEq, Eq, Clone)] - enum TestSumLeafType { - Int(Type), - Unit, - } - - impl TestSumLeafType { - fn assert_valid(&self) { - if let Self::Int(t) = self { - if let TypeEnum::Extension(ct) = t.as_type_enum() { - assert_eq!("int", ct.name()); - assert_eq!(&int_types::EXTENSION_ID, ct.extension()); - } else { - panic!("Expected int type, got {:#?}", t); - } - } - } - - fn get_type(&self) -> Type { - match self { - Self::Int(t) => t.clone(), - Self::Unit => Type::UNIT, - } - } - - fn type_check(&self, ps: &PartialSum) -> bool { - match self { - Self::Int(_) => false, - Self::Unit => { - if let Ok((0, v)) = ps.0.iter().exactly_one() { - v.is_empty() - } else { - false - } - } - } - } - - fn partial_value_strategy(self) -> impl Strategy> { - match self { - Self::Int(t) => { - let TypeEnum::Extension(ct) = t.as_type_enum() else { - unreachable!() - }; - // TODO this should be get_log_width, but that's not pub - let TypeArg::BoundedNat { n: lw } = ct.args()[0] else { - panic!() - }; - (0u64..(1 << (2u64.pow(lw as u32) - 1))) - .prop_map(move |x| { - let ki = ConstInt::new_u(lw as u8, x).unwrap(); - let k = ValueKey::try_new(ki.clone()).unwrap(); - ValueHandle::new(k, Arc::new(ki.into())).into() - }) - .boxed() - } - Self::Unit => Just(PartialValue::new_unit()).boxed(), - } - } - } - - impl Arbitrary for TestSumLeafType { - type Parameters = (); - type Strategy = BoxedStrategy; - fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { - let int_strat = - (0..LOG_WIDTH_BOUND).prop_map(|i| Self::Int(INT_TYPES[i as usize].clone())); - prop_oneof![Just(TestSumLeafType::Unit), int_strat].boxed() - } - } + use super::{BaseValue, PVEnum, PartialSum, PartialValue}; + use crate::dataflow::AbstractValue; #[derive(Debug, PartialEq, Eq, Clone)] enum TestSumType { - Branch(usize, Vec>>), - Leaf(TestSumLeafType), + Branch(Vec>>), + /// None => unit, Some => TestValue <= this *usize* + Leaf(Option), } - impl TestSumType { - #[allow(unused)] // ALAN ? - fn leaf(v: Type) -> Self { - TestSumType::Leaf(TestSumLeafType::Int(v)) - } - - fn branch(vs: impl IntoIterator>>) -> Self { - let vec = vs.into_iter().collect_vec(); - let depth: usize = vec - .iter() - .flat_map(|x| x.iter()) - .map(|x| x.depth() + 1) - .max() - .unwrap_or(0); - Self::Branch(depth, vec) - } - - fn depth(&self) -> usize { - match self { - TestSumType::Branch(x, _) => *x, - TestSumType::Leaf(_) => 0, - } - } - - #[allow(unused)] // ALAN ? - fn is_leaf(&self) -> bool { - self.depth() == 0 - } + #[derive(Clone, Debug, PartialEq, Eq, Hash)] + struct TestValue(usize); - fn assert_valid(&self) { - match self { - TestSumType::Branch(d, sop) => { - assert!(!sop.is_empty(), "No variants"); - for v in sop.iter().flat_map(|x| x.iter()) { - assert!(v.depth() < *d); - v.assert_valid(); - } - } - TestSumType::Leaf(l) => { - l.assert_valid(); - } - } - } + impl BaseValue for TestValue {} - fn select(self) -> impl Strategy>)>> { - match self { - TestSumType::Branch(_, sop) => any::() - .prop_map(move |i| { - let index = i.index(sop.len()); - Either::Right((index, sop[index].clone())) - }) - .boxed(), - TestSumType::Leaf(l) => Just(Either::Left(l)).boxed(), - } - } + #[derive(Clone)] + struct SumTypeParams { + depth: usize, + desired_size: usize, + expected_branch_size: usize, + } - fn get_type(&self) -> Type { - match self { - TestSumType::Branch(_, sop) => Type::new_sum( - sop.iter() - .map(|row| row.iter().map(|x| x.get_type()).collect_vec()), - ), - TestSumType::Leaf(l) => l.get_type(), + impl Default for SumTypeParams { + fn default() -> Self { + Self { + depth: 5, + desired_size: 20, + expected_branch_size: 5, } } + } - fn type_check(&self, pv: &PartialValue) -> bool { + impl TestSumType { + fn type_check(&self, pv: &PartialValue) -> bool { match (self, pv.as_enum()) { (_, PVEnum::Bottom) | (_, PVEnum::Top) => true, - (_, PVEnum::Value(v)) => self.get_type() == v.get_type(), - (TestSumType::Branch(_, sop), PVEnum::Sum(ps)) => { + (Self::Leaf(None), _) => pv == &PartialValue::new_unit(), + (Self::Leaf(Some(max)), PVEnum::Value(TestValue(val))) => val <= max, + (Self::Branch(sop), PVEnum::Sum(ps)) => { for (k, v) in &ps.0 { if *k >= sop.len() { return false; @@ -681,123 +558,126 @@ mod test { } true } - (Self::Leaf(l), PVEnum::Sum(ps)) => l.type_check(ps), - } - } - } - - impl From for TestSumType { - fn from(value: TestSumLeafType) -> Self { - Self::Leaf(value) - } - } - - #[derive(Clone, PartialEq, Eq, Debug)] - struct UnarySumTypeParams { - depth: usize, - branch_width: usize, - } - - impl UnarySumTypeParams { - pub fn descend(mut self, d: usize) -> Self { - assert!(d < self.depth); - self.depth = d; - self - } - } - - impl Default for UnarySumTypeParams { - fn default() -> Self { - Self { - depth: 3, - branch_width: 3, + _ => false, } } } impl Arbitrary for TestSumType { - type Parameters = UnarySumTypeParams; - type Strategy = BoxedStrategy; - fn arbitrary_with( - params @ UnarySumTypeParams { - depth, - branch_width, - }: Self::Parameters, - ) -> Self::Strategy { - if depth == 0 { - any::().prop_map_into().boxed() - } else { - (0..depth) - .prop_flat_map(move |d| { - prop::collection::vec( - prop::collection::vec( - any_with::(params.clone().descend(d)).prop_map_into(), - 0..branch_width, + type Parameters = SumTypeParams; + type Strategy = SBoxedStrategy; + fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { + fn arb(params: SumTypeParams, set: &mut StrategySet) -> SBoxedStrategy { + use proptest::collection::vec; + let int_strat = (0..usize::MAX).prop_map(|i| TestSumType::Leaf(Some(i))); + let leaf_strat = prop_oneof![Just(TestSumType::Leaf(None)), int_strat].sboxed(); + leaf_strat.prop_mutually_recursive( + params.depth as u32, + params.desired_size as u32, + params.expected_branch_size as u32, + set, + move |set| { + let self2 = params.clone(); + vec( + vec( + set.get::(move |set| arb(self2, set)) + .prop_map(Arc::new), + 1..=params.expected_branch_size, ), - 1..=branch_width, + 1..=params.expected_branch_size, ) - .prop_map(TestSumType::branch) - }) - .boxed() - } - } - } - - proptest! { - #[test] - fn unary_sum_type_valid(ust: TestSumType) { - ust.assert_valid(); - } + .prop_map(TestSumType::Branch) + .sboxed() + }, + ) + } + + arb(params, &mut StrategySet::default()) + } + } + + fn partial_sum_strat( + tag: usize, + elems_strat: impl Strategy>>, + ) -> impl Strategy> { + elems_strat.prop_map(move |elems| PartialSum::new_variant(tag, elems)) + } + + // Result gets fed into partial_sum_strat along with tag, so probably inline this into that + fn vec_strat( + elems: &Vec>, + ) -> impl Strategy>> { + elems + .into_iter() + .map(Arc::as_ref) + .map(any_partial_value_of_type) + .collect::>() + } + + fn multi_sum_strat( + variants: &Vec>>, + ) -> impl Strategy> { + let num_tags = variants.len(); + // We have to clone the `variants` here but only as far as the Vec>> + let s = subsequence( + variants.iter().cloned().enumerate().collect::>(), + 1..=num_tags, + ); + let sum_strat: BoxedStrategy>> = s + .prop_flat_map(|selected_tagged_variants| { + selected_tagged_variants + .into_iter() + .map(|(tag, elems)| partial_sum_strat(tag, vec_strat(&elems)).boxed()) + .collect::>() + }) + .boxed(); + sum_strat.prop_map(|psums: Vec>| { + let mut psums = psums.into_iter(); + let first = psums.next().unwrap(); + psums.fold(first, |mut a, b| { + a.try_join_mut(b).unwrap(); + a + }) + }) } fn any_partial_value_of_type( - ust: TestSumType, - ) -> impl Strategy> { - ust.select().prop_flat_map(|x| match x { - Either::Left(l) => l.partial_value_strategy().boxed(), - Either::Right((index, usts)) => { - let pvs = usts - .into_iter() - .map(|x| { - any_partial_value_of_type( - Arc::::try_unwrap(x) - .unwrap_or_else(|x| x.as_ref().clone()), - ) - }) - .collect_vec(); - pvs.prop_map(move |pvs| PartialValue::new_variant(index, pvs)) - .boxed() - } - }) + ust: &TestSumType, + ) -> impl Strategy> { + match ust { + TestSumType::Leaf(None) => Just(PartialValue::new_unit()).boxed(), + TestSumType::Leaf(Some(i)) => (0..*i) + .prop_map(TestValue) + .prop_map(PartialValue::from) + .boxed(), + TestSumType::Branch(sop) => multi_sum_strat(sop).prop_map(PartialValue::from).boxed(), + } } fn any_partial_value_with( params: ::Parameters, - ) -> impl Strategy> { - any_with::(params).prop_flat_map(any_partial_value_of_type) + ) -> impl Strategy> { + any_with::(params).prop_flat_map(|t| any_partial_value_of_type(&t)) } - fn any_partial_value() -> impl Strategy> { + fn any_partial_value() -> impl Strategy> { any_partial_value_with(Default::default()) } - fn any_partial_values() -> impl Strategy; N]> - { + fn any_partial_values() -> impl Strategy; N]> { any::().prop_flat_map(|ust| { TryInto::<[_; N]>::try_into( (0..N) - .map(|_| any_partial_value_of_type(ust.clone())) + .map(|_| any_partial_value_of_type(&ust)) .collect_vec(), ) .unwrap() }) } - fn any_typed_partial_value() -> impl Strategy)> - { - any::().prop_flat_map(|t| { - any_partial_value_of_type(t.clone()).prop_map(move |v| (t.clone(), v)) - }) + fn any_typed_partial_value() -> impl Strategy)> { + any::() + .prop_flat_map(|t| any_partial_value_of_type(&t).prop_map(move |v| (t.clone(), v))) } proptest! { From 2a57a1517ada040bee083519d232d0339f6c5579 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 17:06:43 +0100 Subject: [PATCH 081/281] tests: Rename type_check -> check_value --- hugr-passes/src/dataflow/partial_value.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 272e43f2a..905f2547c 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -538,7 +538,7 @@ mod test { } impl TestSumType { - fn type_check(&self, pv: &PartialValue) -> bool { + fn check_value(&self, pv: &PartialValue) -> bool { match (self, pv.as_enum()) { (_, PVEnum::Bottom) | (_, PVEnum::Top) => true, (Self::Leaf(None), _) => pv == &PartialValue::new_unit(), @@ -552,7 +552,7 @@ mod test { if prod.len() != v.len() { return false; } - if !zip_eq(prod, v).all(|(lhs, rhs)| lhs.type_check(rhs)) { + if !zip_eq(prod, v).all(|(lhs, rhs)| lhs.check_value(rhs)) { return false; } } @@ -683,7 +683,7 @@ mod test { proptest! { #[test] fn partial_value_type((tst, pv) in any_typed_partial_value()) { - prop_assert!(tst.type_check(&pv)) + prop_assert!(tst.check_value(&pv)) } // todo: ValidHandle is valid From fcfcb6b8b748a44f668d067b016e0ea227904903 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 17:09:30 +0100 Subject: [PATCH 082/281] tidies --- hugr-passes/src/dataflow/partial_value.rs | 50 ++++++++++------------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 905f2547c..a3bea6b19 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -570,17 +570,17 @@ mod test { fn arb(params: SumTypeParams, set: &mut StrategySet) -> SBoxedStrategy { use proptest::collection::vec; let int_strat = (0..usize::MAX).prop_map(|i| TestSumType::Leaf(Some(i))); - let leaf_strat = prop_oneof![Just(TestSumType::Leaf(None)), int_strat].sboxed(); + let leaf_strat = prop_oneof![Just(TestSumType::Leaf(None)), int_strat]; leaf_strat.prop_mutually_recursive( params.depth as u32, params.desired_size as u32, params.expected_branch_size as u32, set, move |set| { - let self2 = params.clone(); + let params2 = params.clone(); vec( vec( - set.get::(move |set| arb(self2, set)) + set.get::(move |set| arb(params2, set)) .prop_map(Arc::new), 1..=params.expected_branch_size, ), @@ -596,42 +596,34 @@ mod test { } } - fn partial_sum_strat( + fn single_sum_strat( tag: usize, - elems_strat: impl Strategy>>, + elems: Vec>, ) -> impl Strategy> { - elems_strat.prop_map(move |elems| PartialSum::new_variant(tag, elems)) - } - - // Result gets fed into partial_sum_strat along with tag, so probably inline this into that - fn vec_strat( - elems: &Vec>, - ) -> impl Strategy>> { elems - .into_iter() + .iter() .map(Arc::as_ref) .map(any_partial_value_of_type) .collect::>() + .prop_map(move |elems| PartialSum::new_variant(tag, elems)) } - fn multi_sum_strat( + fn partial_sum_strat( variants: &Vec>>, ) -> impl Strategy> { - let num_tags = variants.len(); // We have to clone the `variants` here but only as far as the Vec>> - let s = subsequence( - variants.iter().cloned().enumerate().collect::>(), - 1..=num_tags, - ); - let sum_strat: BoxedStrategy>> = s - .prop_flat_map(|selected_tagged_variants| { - selected_tagged_variants - .into_iter() - .map(|(tag, elems)| partial_sum_strat(tag, vec_strat(&elems)).boxed()) - .collect::>() - }) - .boxed(); - sum_strat.prop_map(|psums: Vec>| { + let tagged_variants = variants.iter().cloned().enumerate().collect::>(); + // The type annotation here (and the .boxed() enabling it) are just for documentation + let sum_variants_strat: BoxedStrategy>> = + subsequence(tagged_variants, 1..=variants.len()) + .prop_flat_map(|selected_variants| { + selected_variants + .into_iter() + .map(|(tag, elems)| single_sum_strat(tag, elems)) + .collect::>() + }) + .boxed(); + sum_variants_strat.prop_map(|psums: Vec>| { let mut psums = psums.into_iter(); let first = psums.next().unwrap(); psums.fold(first, |mut a, b| { @@ -650,7 +642,7 @@ mod test { .prop_map(TestValue) .prop_map(PartialValue::from) .boxed(), - TestSumType::Branch(sop) => multi_sum_strat(sop).prop_map(PartialValue::from).boxed(), + TestSumType::Branch(sop) => partial_sum_strat(sop).prop_map(PartialValue::from).boxed(), } } From dcaa928517c5914518c24b26edaea784b8b46c0a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 17:56:37 +0100 Subject: [PATCH 083/281] Add a couple more proptests, and a TEMPORARY FIX for a BUG pending better answer --- hugr-passes/src/dataflow/partial_value.rs | 28 ++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index a3bea6b19..dcce791db 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -451,7 +451,16 @@ impl Lattice for PartialValue { _ => unreachable!(), }; match ps1.try_meet_mut(ps2) { - Ok(ch) => ch, + Ok(ch) => { + // ALAN the 'invariant' that a PartialSum always has >=1 tag can be broken here. + // Fix this by rewriting to Bottom, but should probably be refactored - at the + // least, it seems dangerous to expose a potentially-invalidating try_meet_mut. + if ps1.0.is_empty() { + assert!(ch); + self.0 = PVEnum::Bottom + } + ch + } Err(_) => { self.0 = PVEnum::Bottom; true @@ -712,10 +721,27 @@ mod test { let meet = v1.clone().meet(v2.clone()); prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); + prop_assert!(meet == v2.clone().meet(v1.clone()), "meet not symmetric"); + prop_assert!(meet == meet.clone().meet(v1.clone()), "repeated meet should be a no-op"); + prop_assert!(meet == meet.clone().meet(v2.clone()), "repeated meet should be a no-op"); let join = v1.clone().join(v2.clone()); prop_assert!(join >= v1, "join not >=: {:#?}", &join); prop_assert!(join >= v2, "join not >=: {:#?}", &join); + prop_assert!(join == v2.clone().join(v1.clone()), "join not symmetric"); + prop_assert!(join == join.clone().join(v1.clone()), "repeated join should be a no-op"); + prop_assert!(join == join.clone().join(v2.clone()), "repeated join should be a no-op"); + } + + #[test] + fn lattice_associative([v1, v2, v3] in any_partial_values()) { + let a = v1.clone().meet(v2.clone()).meet(v3.clone()); + let b = v1.clone().meet(v2.clone().meet(v3.clone())); + prop_assert!(a==b, "meet not associative"); + + let a = v1.clone().join(v2.clone()).join(v3.clone()); + let b = v1.clone().join(v2.clone().join(v3.clone())); + prop_assert!(a==b, "join not associative") } } } From bcacbcca2d23991771bcc49d711161943c547be3 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 20:07:56 +0100 Subject: [PATCH 084/281] Remove redundant test --- hugr-passes/src/dataflow/test.rs | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 127dcc373..6f7b2ef52 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -64,27 +64,6 @@ fn test_unpack_tuple_const() { assert_eq!(o2_r, Value::true_val()); } -#[test] -fn test_unpack_const() { - let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); - let v1 = builder.add_load_value(Value::tuple([Value::true_val()])); - let [o] = builder - .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T]), [v1]) - .unwrap() - .outputs_arr(); - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - - let mut machine = Machine::default(); - machine.run(HugrValueContext::new(&hugr)); - - let o_r = machine - .read_out_wire(o) - .unwrap() - .try_into_wire_value(&hugr, o) - .unwrap(); - assert_eq!(o_r, Value::true_val()); -} - #[test] fn test_tail_loop_never_iterates() { let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); From 5192ed8ab8ed1b66d0c1ae9afd8da408f0361abd Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 20:48:57 +0100 Subject: [PATCH 085/281] Refactor TailLoopTermination::from_control_value --- hugr-passes/src/dataflow/machine.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 986fafa76..5a86dd98e 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -108,10 +108,12 @@ pub enum TailLoopTermination { impl TailLoopTermination { pub fn from_control_value(v: &impl AbstractValue) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); - if may_break && !may_continue { - Self::ExactlyZeroContinues - } else if may_break && may_continue { - Self::Top + if may_break { + if may_continue { + Self::Top + } else { + Self::ExactlyZeroContinues + } } else { Self::Bottom } From e21bbd759ea13176764200086863dfc5f78075e2 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 20:57:40 +0100 Subject: [PATCH 086/281] pub TailLoopTermination, rename members, doc --- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/dataflow/machine.rs | 22 ++++++++++++++++------ hugr-passes/src/dataflow/test.rs | 6 +++--- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index a66edde03..6f437f882 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -5,7 +5,7 @@ mod datalog; pub use datalog::{AbstractValue, DFContext}; mod machine; -pub use machine::Machine; +pub use machine::{Machine, TailLoopTermination}; mod partial_value; pub use partial_value::{BaseValue, PVEnum, PartialSum, PartialValue, Sum}; diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 5a86dd98e..aa9408cdb 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -98,24 +98,34 @@ impl> Machine { } } +/// Tells whether a loop iterates (never, always, sometimes) #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] pub enum TailLoopTermination { - Bottom, - ExactlyZeroContinues, - Top, + /// The loop never exits (is an infinite loop); no value is ever + /// returned out of the loop. (aka, Bottom.) + // TODO what about a loop that never exits OR continues because of a nested infinite loop? + NeverBreaks, + /// The loop never iterates (so is equivalent to a [DFG](hugr_core::ops::DFG), + /// modulo untupling of the control value) + NeverContinues, + /// The loop might iterate and/or exit. (aka, Top) + BreaksAndContinues, } impl TailLoopTermination { + /// Extracts the relevant information from a value that should represent + /// the value provided to the [Output](hugr_core::ops::Output) node child + /// of the [TailLoop](hugr_core::ops::TailLoop) pub fn from_control_value(v: &impl AbstractValue) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); if may_break { if may_continue { - Self::Top + Self::BreaksAndContinues } else { - Self::ExactlyZeroContinues + Self::NeverContinues } } else { - Self::Bottom + Self::NeverBreaks } } } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 6f7b2ef52..e9aeb4c57 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -95,7 +95,7 @@ fn test_tail_loop_never_iterates() { .unwrap(); assert_eq!(o_r, r_v); assert_eq!( - Some(TailLoopTermination::ExactlyZeroContinues), + Some(TailLoopTermination::NeverContinues), machine.tail_loop_terminates(&hugr, tail_loop.node()) ) } @@ -127,7 +127,7 @@ fn test_tail_loop_always_iterates() { let o_r2 = machine.read_out_wire(tl_o2).unwrap(); assert_eq!(o_r2, PartialValue::bottom()); assert_eq!( - Some(TailLoopTermination::Bottom), + Some(TailLoopTermination::NeverBreaks), machine.tail_loop_terminates(&hugr, tail_loop.node()) ); assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); @@ -181,7 +181,7 @@ fn test_tail_loop_iterates_twice() { let _ = machine.read_out_wire(o_w2).unwrap(); // assert_eq!(o_r2, Value::true_val()); assert_eq!( - Some(TailLoopTermination::Top), + Some(TailLoopTermination::BreaksAndContinues), machine.tail_loop_terminates(&hugr, tail_loop.node()) ); assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); From 22e0192dbbf21cd8156871cecc9c791151b76cae Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 21:53:56 +0100 Subject: [PATCH 087/281] Test tidies (and some ALAN wtf? comments) --- hugr-passes/src/dataflow/test.rs | 45 +++++++++++--------------------- 1 file changed, 15 insertions(+), 30 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index e9aeb4c57..408c9949d 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,7 +1,4 @@ -use crate::{ - const_fold2::HugrValueContext, - dataflow::{machine::TailLoopTermination, AbstractValue, Machine}, -}; +use crate::const_fold2::HugrValueContext; use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::{ @@ -10,13 +7,13 @@ use hugr_core::{ prelude::{UnpackTuple, BOOL_T}, ExtensionSet, EMPTY_REG, }, - ops::{handle::NodeHandle, OpTrait, Value}, + ops::{handle::NodeHandle, DataflowOpTrait, Value}, type_row, - types::{Signature, SumType, Type, TypeRow}, + types::{Signature, SumType, Type}, HugrView, }; -use super::partial_value::PartialValue; +use super::{AbstractValue, Machine, PartialValue, TailLoopTermination}; #[test] fn test_make_tuple() { @@ -85,8 +82,6 @@ fn test_tail_loop_never_iterates() { let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); - // dbg!(&machine.tail_loop_io_node); - // dbg!(&machine.out_wire_value); let o_r = machine .read_out_wire(tl_o) @@ -152,34 +147,26 @@ fn test_tail_loop_iterates_twice() { ) .unwrap(); assert_eq!( - tlb.loop_signature().unwrap().dataflow_signature().unwrap(), + tlb.loop_signature().unwrap().signature(), Signature::new_endo(type_row![BOOL_T, BOOL_T]) ); let [in_w1, in_w2] = tlb.input_wires_arr(); let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap(); - // let optype = builder.hugr().get_optype(tail_loop.node()); - // for p in builder.hugr().node_outputs(tail_loop.node()) { - // use hugr_core::ops::OpType; - // println!("{:?}, {:?}", p, optype.port_kind(p)); - - // } - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); // TODO once we can do conditionals put these wires inside `just_outputs` and - // we should be able to propagate their values + // we should be able to propagate their values...ALAN wtf? loop control type IS bool ATM let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); - // dbg!(&machine.tail_loop_io_node); - // dbg!(&machine.out_wire_value); - - // TODO these hould be the propagated values for now they will bt join(true,false) - let _ = machine.read_out_wire(o_w1).unwrap(); - // assert_eq!(o_r1, PartialValue::top()); - let _ = machine.read_out_wire(o_w2).unwrap(); - // assert_eq!(o_r2, Value::true_val()); + + let true_or_false = PartialValue::new_variant(0, []).join(PartialValue::new_variant(1, [])); + // TODO these should be the propagated values for now they will be join(true,false) - ALAN wtf? + let o_r1 = machine.read_out_wire(o_w1).unwrap(); + assert_eq!(o_r1, true_or_false); + let o_r2 = machine.read_out_wire(o_w2).unwrap(); + assert_eq!(o_r2, true_or_false); assert_eq!( Some(TailLoopTermination::BreaksAndContinues), machine.tail_loop_terminates(&hugr, tail_loop.node()) @@ -191,19 +178,17 @@ fn test_tail_loop_iterates_twice() { fn conditional() { let variants = vec![type_row![], type_row![], type_row![BOOL_T]]; let cond_t = Type::new_sum(variants.clone()); - let mut builder = - DFGBuilder::new(Signature::new(Into::::into(cond_t), type_row![])).unwrap(); + let mut builder = DFGBuilder::new(Signature::new(cond_t, type_row![])).unwrap(); let [arg_w] = builder.input_wires_arr(); let true_w = builder.add_load_value(Value::true_val()); let false_w = builder.add_load_value(Value::false_val()); let mut cond_builder = builder - .conditional_builder_exts( + .conditional_builder( (variants, arg_w), [(BOOL_T, true_w)], type_row!(BOOL_T, BOOL_T), - ExtensionSet::default(), ) .unwrap(); // will be unreachable From cae5e4fe44e5a5ba68a63fce9814a87e15a6e74d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:07:39 +0100 Subject: [PATCH 088/281] Use Tag --- hugr-passes/src/dataflow/test.rs | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 408c9949d..4cbc3b893 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -7,7 +7,7 @@ use hugr_core::{ prelude::{UnpackTuple, BOOL_T}, ExtensionSet, EMPTY_REG, }, - ops::{handle::NodeHandle, DataflowOpTrait, Value}, + ops::{handle::NodeHandle, DataflowOpTrait, Tag, Value}, type_row, types::{Signature, SumType, Type}, HugrView, @@ -65,18 +65,14 @@ fn test_unpack_tuple_const() { fn test_tail_loop_never_iterates() { let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); let r_v = Value::unit_sum(3, 6).unwrap(); - let r_w = builder.add_load_value( - Value::sum( - 1, - [r_v.clone()], - SumType::new([type_row![], r_v.get_type().into()]), - ) - .unwrap(), - ); + let r_w = builder.add_load_value(r_v.clone()); + let tag = Tag::new(1, vec![type_row![], r_v.get_type().into()]); + let tagged = builder.add_dataflow_op(tag, [r_w]).unwrap(); + let tlb = builder .tail_loop_builder([], [], vec![r_v.get_type()].into()) .unwrap(); - let tail_loop = tlb.finish_with_outputs(r_w, []).unwrap(); + let tail_loop = tlb.finish_with_outputs(tagged.out_wire(0), []).unwrap(); let [tl_o] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); From 3014827c94002d9be8c859d628e19502016c0b1f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:08:38 +0100 Subject: [PATCH 089/281] Add TestContext (no interpret_leaf_op), propolutate, avoid HugrValueContext --- hugr-passes/src/dataflow/test.rs | 94 +++++++++++++++++++++++++++++--- 1 file changed, 86 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 4cbc3b893..4f42b4a4e 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,4 +1,5 @@ -use crate::const_fold2::HugrValueContext; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::{ @@ -13,7 +14,80 @@ use hugr_core::{ HugrView, }; -use super::{AbstractValue, Machine, PartialValue, TailLoopTermination}; +use super::{AbstractValue, BaseValue, DFContext, Machine, PartialValue, TailLoopTermination}; + +// ------- Minimal implementation of DFContext and BaseValue ------- +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum Void {} + +impl BaseValue for Void {} + +struct TestContext(Arc); + +// Deriving Clone requires H:HugrView to implement Clone, +// but we don't need that as we only clone the Arc. +impl Clone for TestContext { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl std::ops::Deref for TestContext { + type Target = hugr_core::Hugr; + + fn deref(&self) -> &Self::Target { + self.0.base_hugr() + } +} + +// Any value used in an Ascent program must be hashable. +// However, there should only be one DFContext, so its hash is immaterial. +impl Hash for TestContext { + fn hash(&self, _state: &mut I) {} +} + +impl PartialEq for TestContext { + fn eq(&self, other: &Self) -> bool { + // Any AscentProgram should have only one DFContext (maybe cloned) + assert!(Arc::ptr_eq(&self.0, &other.0)); + true + } +} + +impl Eq for TestContext {} + +impl PartialOrd for TestContext { + fn partial_cmp(&self, other: &Self) -> Option { + // Any AscentProgram should have only one DFContext (maybe cloned) + assert!(Arc::ptr_eq(&self.0, &other.0)); + Some(std::cmp::Ordering::Equal) + } +} + +impl DFContext> for TestContext { + fn interpret_leaf_op( + &self, + _node: hugr_core::Node, + _ins: &[PartialValue], + ) -> Option>> { + None + } +} + +// This allows testing creation of tuple/sum Values (only) +impl From for Value { + fn from(v: Void) -> Self { + match v {} + } +} + +fn pv_false() -> PartialValue { + PartialValue::new_variant(0, []) +} + +fn pv_true() -> PartialValue { + PartialValue::new_variant(1, []) +} #[test] fn test_make_tuple() { @@ -24,7 +98,8 @@ fn test_make_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.run(HugrValueContext::new(&hugr)); + machine.propolutate_out_wires([(v1, pv_false()), (v2, pv_true())]); + machine.run(TestContext(Arc::new(&hugr))); let x = machine .read_out_wire(v3) @@ -45,7 +120,8 @@ fn test_unpack_tuple_const() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.run(HugrValueContext::new(&hugr)); + machine.propolutate_out_wires([(v, PartialValue::new_variant(0, [pv_false(), pv_true()]))]); + machine.run(TestContext(Arc::new(&hugr))); let o1_r = machine .read_out_wire(o1) @@ -77,7 +153,8 @@ fn test_tail_loop_never_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.run(HugrValueContext::new(&hugr)); + machine.propolutate_out_wires([(r_w, PartialValue::new_variant(3, []))]); + machine.run(TestContext(Arc::new(&hugr))); let o_r = machine .read_out_wire(tl_o) @@ -111,7 +188,7 @@ fn test_tail_loop_always_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.run(HugrValueContext::new(&hugr)); + machine.run(TestContext(Arc::new(&hugr))); let o_r1 = machine.read_out_wire(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom()); @@ -155,7 +232,8 @@ fn test_tail_loop_iterates_twice() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); - machine.run(HugrValueContext::new(&hugr)); + machine.propolutate_out_wires([(true_w, pv_true()), (false_w, pv_false())]); + machine.run(TestContext(Arc::new(&hugr))); let true_or_false = PartialValue::new_variant(0, []).join(PartialValue::new_variant(1, [])); // TODO these should be the propagated values for now they will be join(true,false) - ALAN wtf? @@ -211,7 +289,7 @@ fn conditional() { [PartialValue::new_variant(0, [])], )); machine.propolutate_out_wires([(arg_w, arg_pv)]); - machine.run(HugrValueContext::new(&hugr)); + machine.run(TestContext(Arc::new(&hugr))); let cond_r1 = machine .read_out_wire(cond_o1) From 64b9bb753b4904237569f2ebc39f80dab7c97dca Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:33:07 +0100 Subject: [PATCH 090/281] Avoid propolutate by interpreting LoadConstant (only) --- hugr-passes/src/dataflow/test.rs | 34 +++++++++++++++++++------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 4f42b4a4e..82a8f6f89 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -13,6 +13,7 @@ use hugr_core::{ types::{Signature, SumType, Type}, HugrView, }; +use itertools::Itertools; use super::{AbstractValue, BaseValue, DFContext, Machine, PartialValue, TailLoopTermination}; @@ -67,10 +68,27 @@ impl PartialOrd for TestContext { impl DFContext> for TestContext { fn interpret_leaf_op( &self, - _node: hugr_core::Node, + node: hugr_core::Node, _ins: &[PartialValue], ) -> Option>> { - None + // Interpret LoadConstants of sums of sums (without leaves), only + fn try_into_pv(v: &Value) -> Option> { + let Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) = v else { + return None; + }; + Some(PartialValue::new_variant( + *tag, + values + .iter() + .map(try_into_pv) + .collect::>>>()?, + )) + } + self.0.get_optype(node).as_load_constant().and_then(|_| { + let const_node = self.0.input_neighbours(node).exactly_one().ok().unwrap(); + let v = self.0.get_optype(const_node).as_const().unwrap().value(); + try_into_pv(v).map(|v| vec![v]) + }) } } @@ -81,14 +99,6 @@ impl From for Value { } } -fn pv_false() -> PartialValue { - PartialValue::new_variant(0, []) -} - -fn pv_true() -> PartialValue { - PartialValue::new_variant(1, []) -} - #[test] fn test_make_tuple() { let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); @@ -98,7 +108,6 @@ fn test_make_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(v1, pv_false()), (v2, pv_true())]); machine.run(TestContext(Arc::new(&hugr))); let x = machine @@ -120,7 +129,6 @@ fn test_unpack_tuple_const() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(v, PartialValue::new_variant(0, [pv_false(), pv_true()]))]); machine.run(TestContext(Arc::new(&hugr))); let o1_r = machine @@ -153,7 +161,6 @@ fn test_tail_loop_never_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(r_w, PartialValue::new_variant(3, []))]); machine.run(TestContext(Arc::new(&hugr))); let o_r = machine @@ -232,7 +239,6 @@ fn test_tail_loop_iterates_twice() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(true_w, pv_true()), (false_w, pv_false())]); machine.run(TestContext(Arc::new(&hugr))); let true_or_false = PartialValue::new_variant(0, []).join(PartialValue::new_variant(1, [])); From 8bc5e122e6e8b88ff74e5e8775cbb9ba574da221 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:36:34 +0100 Subject: [PATCH 091/281] Revert "Avoid propolutate by interpreting LoadConstant (only)" This reverts commit 64b9bb753b4904237569f2ebc39f80dab7c97dca. --- hugr-passes/src/dataflow/test.rs | 34 +++++++++++++------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 82a8f6f89..4f42b4a4e 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -13,7 +13,6 @@ use hugr_core::{ types::{Signature, SumType, Type}, HugrView, }; -use itertools::Itertools; use super::{AbstractValue, BaseValue, DFContext, Machine, PartialValue, TailLoopTermination}; @@ -68,27 +67,10 @@ impl PartialOrd for TestContext { impl DFContext> for TestContext { fn interpret_leaf_op( &self, - node: hugr_core::Node, + _node: hugr_core::Node, _ins: &[PartialValue], ) -> Option>> { - // Interpret LoadConstants of sums of sums (without leaves), only - fn try_into_pv(v: &Value) -> Option> { - let Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) = v else { - return None; - }; - Some(PartialValue::new_variant( - *tag, - values - .iter() - .map(try_into_pv) - .collect::>>>()?, - )) - } - self.0.get_optype(node).as_load_constant().and_then(|_| { - let const_node = self.0.input_neighbours(node).exactly_one().ok().unwrap(); - let v = self.0.get_optype(const_node).as_const().unwrap().value(); - try_into_pv(v).map(|v| vec![v]) - }) + None } } @@ -99,6 +81,14 @@ impl From for Value { } } +fn pv_false() -> PartialValue { + PartialValue::new_variant(0, []) +} + +fn pv_true() -> PartialValue { + PartialValue::new_variant(1, []) +} + #[test] fn test_make_tuple() { let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); @@ -108,6 +98,7 @@ fn test_make_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); + machine.propolutate_out_wires([(v1, pv_false()), (v2, pv_true())]); machine.run(TestContext(Arc::new(&hugr))); let x = machine @@ -129,6 +120,7 @@ fn test_unpack_tuple_const() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); + machine.propolutate_out_wires([(v, PartialValue::new_variant(0, [pv_false(), pv_true()]))]); machine.run(TestContext(Arc::new(&hugr))); let o1_r = machine @@ -161,6 +153,7 @@ fn test_tail_loop_never_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); + machine.propolutate_out_wires([(r_w, PartialValue::new_variant(3, []))]); machine.run(TestContext(Arc::new(&hugr))); let o_r = machine @@ -239,6 +232,7 @@ fn test_tail_loop_iterates_twice() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); + machine.propolutate_out_wires([(true_w, pv_true()), (false_w, pv_false())]); machine.run(TestContext(Arc::new(&hugr))); let true_or_false = PartialValue::new_variant(0, []).join(PartialValue::new_variant(1, [])); From a3a6213f40d4ac6b22bb7b70f66f00778fe4851b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:20:48 +0100 Subject: [PATCH 092/281] tiny const_fold2 doc tweaks --- hugr-passes/src/const_fold2/context.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs index 32fc57765..6338629df 100644 --- a/hugr-passes/src/const_fold2/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -8,15 +8,13 @@ use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; use super::value_handle::{ValueHandle, ValueKey}; use crate::dataflow::TotalContext; -/// A context ([DFContext]) for doing analysis with [ValueHandle]s. +/// A [context](crate::dataflow::DFContext) for doing analysis with [ValueHandle]s. /// Interprets [LoadConstant](OpType::LoadConstant) nodes, /// and [ExtensionOp](OpType::ExtensionOp) nodes where the extension does /// (using [Value]s for extension-op inputs). /// /// Just stores a Hugr (actually any [HugrView]), /// (there is )no state for operation-interpretation. -/// -/// [DFContext]: crate::dataflow::DFContext #[derive(Debug)] pub struct HugrValueContext(Arc); From a96ab20ac7f24001c81ccbd846b94bbc86efc722 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:22:04 +0100 Subject: [PATCH 093/281] (TEMP) remove const_fold2 module --- hugr-passes/src/const_fold2.rs | 8 - hugr-passes/src/const_fold2/context.rs | 101 --------- hugr-passes/src/const_fold2/value_handle.rs | 222 -------------------- hugr-passes/src/lib.rs | 1 - 4 files changed, 332 deletions(-) delete mode 100644 hugr-passes/src/const_fold2.rs delete mode 100644 hugr-passes/src/const_fold2/context.rs delete mode 100644 hugr-passes/src/const_fold2/value_handle.rs diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs deleted file mode 100644 index 58f285d43..000000000 --- a/hugr-passes/src/const_fold2.rs +++ /dev/null @@ -1,8 +0,0 @@ -#![warn(missing_docs)] -//! An (example) use of the [super::dataflow](dataflow-analysis framework) -//! to perform constant-folding. - -// These are pub because this "example" is used for testing the framework. -mod context; -pub mod value_handle; -pub use context::HugrValueContext; diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs deleted file mode 100644 index 6338629df..000000000 --- a/hugr-passes/src/const_fold2/context.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::hash::{Hash, Hasher}; -use std::ops::Deref; -use std::sync::Arc; - -use hugr_core::ops::{OpType, Value}; -use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; - -use super::value_handle::{ValueHandle, ValueKey}; -use crate::dataflow::TotalContext; - -/// A [context](crate::dataflow::DFContext) for doing analysis with [ValueHandle]s. -/// Interprets [LoadConstant](OpType::LoadConstant) nodes, -/// and [ExtensionOp](OpType::ExtensionOp) nodes where the extension does -/// (using [Value]s for extension-op inputs). -/// -/// Just stores a Hugr (actually any [HugrView]), -/// (there is )no state for operation-interpretation. -#[derive(Debug)] -pub struct HugrValueContext(Arc); - -impl HugrValueContext { - /// Creates a new instance, given ownership of the [HugrView] - pub fn new(hugr: H) -> Self { - Self(Arc::new(hugr)) - } -} - -// Deriving Clone requires H:HugrView to implement Clone, -// but we don't need that as we only clone the Arc. -impl Clone for HugrValueContext { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} - -// Any value used in an Ascent program must be hashable. -// However, there should only be one DFContext, so its hash is immaterial. -impl Hash for HugrValueContext { - fn hash(&self, _state: &mut I) {} -} - -impl PartialEq for HugrValueContext { - fn eq(&self, other: &Self) -> bool { - // Any AscentProgram should have only one DFContext (maybe cloned) - assert!(Arc::ptr_eq(&self.0, &other.0)); - true - } -} - -impl Eq for HugrValueContext {} - -impl PartialOrd for HugrValueContext { - fn partial_cmp(&self, other: &Self) -> Option { - // Any AscentProgram should have only one DFContext (maybe cloned) - assert!(Arc::ptr_eq(&self.0, &other.0)); - Some(std::cmp::Ordering::Equal) - } -} - -impl Deref for HugrValueContext { - type Target = Hugr; - - fn deref(&self) -> &Self::Target { - self.0.base_hugr() - } -} - -impl TotalContext for HugrValueContext { - type InterpretableVal = Value; - - fn interpret_leaf_op( - &self, - n: Node, - ins: &[(IncomingPort, Value)], - ) -> Vec<(OutgoingPort, ValueHandle)> { - match self.0.get_optype(n) { - OpType::LoadConstant(load_op) => { - assert!(ins.is_empty()); // static edge, so need to find constant - let const_node = self - .0 - .single_linked_output(n, load_op.constant_port()) - .unwrap() - .0; - let const_op = self.0.get_optype(const_node).as_const().unwrap(); - vec![( - OutgoingPort::from(0), - ValueHandle::new(const_node.into(), Arc::new(const_op.value().clone())), - )] - } - OpType::ExtensionOp(op) => { - let ins = ins.iter().map(|(p, v)| (*p, v.clone())).collect::>(); - op.constant_fold(&ins).map_or(Vec::new(), |outs| { - outs.into_iter() - .map(|(p, v)| (p, ValueHandle::new(ValueKey::Node(n), Arc::new(v)))) - .collect() - }) - } - _ => vec![], - } - } -} diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs deleted file mode 100644 index 59a08b50a..000000000 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ /dev/null @@ -1,222 +0,0 @@ -use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76. -use std::hash::{Hash, Hasher}; -use std::sync::Arc; - -use hugr_core::ops::constant::{CustomConst, Sum}; -use hugr_core::ops::Value; -use hugr_core::types::Type; -use hugr_core::Node; - -use crate::dataflow::BaseValue; - -#[derive(Clone, Debug)] -pub struct HashedConst { - hash: u64, - val: Arc, -} - -impl PartialEq for HashedConst { - fn eq(&self, other: &Self) -> bool { - self.hash == other.hash && self.val.equal_consts(other.val.as_ref()) - } -} - -impl Eq for HashedConst {} - -impl Hash for HashedConst { - fn hash(&self, state: &mut H) { - state.write_u64(self.hash); - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum ValueKey { - Field(usize, Box), - Const(HashedConst), - Node(Node), -} - -impl From for ValueKey { - fn from(n: Node) -> Self { - Self::Node(n) - } -} - -impl From for ValueKey { - fn from(value: HashedConst) -> Self { - Self::Const(value) - } -} - -impl ValueKey { - pub fn new(n: Node, k: impl CustomConst) -> Self { - Self::try_new(k).unwrap_or(Self::Node(n)) - } - - pub fn try_new(cst: impl CustomConst) -> Option { - let mut hasher = DefaultHasher::new(); - cst.try_hash(&mut hasher).then(|| { - Self::Const(HashedConst { - hash: hasher.finish(), - val: Arc::new(cst), - }) - }) - } - - fn field(self, i: usize) -> Self { - Self::Field(i, Box::new(self)) - } -} - -#[derive(Clone, Debug)] -pub struct ValueHandle(ValueKey, Arc); - -impl ValueHandle { - pub fn new(key: ValueKey, value: Arc) -> Self { - Self(key, value) - } - - pub fn value(&self) -> &Value { - self.1.as_ref() - } - - pub fn get_type(&self) -> Type { - self.1.get_type() - } -} - -impl BaseValue for ValueHandle { - fn as_sum(&self) -> Option<(usize, impl Iterator + '_)> { - match self.value() { - Value::Sum(Sum { tag, values, .. }) => Some(( - *tag, - values - .iter() - .enumerate() - .map(|(i, v)| Self(self.0.clone().field(i), Arc::new(v.clone()))), - )), - _ => None, - } - } -} - -impl PartialEq for ValueHandle { - fn eq(&self, other: &Self) -> bool { - // If the keys are equal, we return true since the values must have the - // same provenance, and so be equal. If the keys are different but the - // values are equal, we could return true if we didn't impl Eq, but - // since we do impl Eq, the Hash contract prohibits us from having equal - // values with different hashes. - let r = self.0 == other.0; - if r { - debug_assert_eq!(self.get_type(), other.get_type()); - } - r - } -} - -impl Eq for ValueHandle {} - -impl Hash for ValueHandle { - fn hash(&self, state: &mut I) { - self.0.hash(state); - } -} - -impl From for Value { - fn from(value: ValueHandle) -> Self { - (*value.1).clone() - } -} - -#[cfg(test)] -mod test { - use hugr_core::{ - extension::prelude::ConstString, - ops::constant::CustomConst as _, - std_extensions::{ - arithmetic::{ - float_types::{ConstF64, FLOAT64_TYPE}, - int_types::{ConstInt, INT_TYPES}, - }, - collections::ListValue, - }, - types::SumType, - }; - - use super::*; - - #[test] - fn value_key_eq() { - let n = Node::from(portgraph::NodeIndex::new(0)); - let n2: Node = portgraph::NodeIndex::new(1).into(); - let k1 = ValueKey::new(n, ConstString::new("foo".to_string())); - let k2 = ValueKey::new(n2, ConstString::new("foo".to_string())); - let k3 = ValueKey::new(n, ConstString::new("bar".to_string())); - - assert_eq!(k1, k2); // Node ignored - assert_ne!(k1, k3); - - assert_eq!(ValueKey::from(n), ValueKey::from(n)); - let f = ConstF64::new(std::f64::consts::PI); - assert_eq!(ValueKey::new(n, f.clone()), ValueKey::from(n)); - - assert_ne!(ValueKey::new(n, f.clone()), ValueKey::new(n2, f)); // Node taken into account - let k4 = ValueKey::from(n); - let k5 = ValueKey::from(n); - let k6: ValueKey = ValueKey::from(n2); - - assert_eq!(&k4, &k5); - assert_ne!(&k4, &k6); - - let k7 = k5.clone().field(3); - let k4 = k4.field(3); - - assert_eq!(&k4, &k7); - - let k5 = k5.field(2); - - assert_ne!(&k5, &k7); - } - - #[test] - fn value_key_list() { - let v1 = ConstInt::new_u(3, 3).unwrap(); - let v2 = ConstInt::new_u(4, 3).unwrap(); - let v3 = ConstF64::new(std::f64::consts::PI); - - let n = Node::from(portgraph::NodeIndex::new(0)); - let n2: Node = portgraph::NodeIndex::new(1).into(); - - let lst = ListValue::new(INT_TYPES[0].clone(), [v1.into(), v2.into()]); - assert_eq!(ValueKey::new(n, lst.clone()), ValueKey::new(n2, lst)); - - let lst = ListValue::new(FLOAT64_TYPE, [v3.into()]); - assert_ne!( - ValueKey::new(n, lst.clone()), - ValueKey::new(n2, lst.clone()) - ); - } - - #[test] - fn value_handle_eq() { - let k_i = ConstInt::new_u(4, 2).unwrap(); - let subject_val = Arc::new( - Value::sum( - 0, - [k_i.clone().into()], - SumType::new([vec![k_i.get_type()], vec![]]), - ) - .unwrap(), - ); - - let k1 = ValueKey::try_new(ConstString::new("foo".to_string())).unwrap(); - let v1 = ValueHandle::new(k1.clone(), subject_val.clone()); - let v2 = ValueHandle::new(k1.clone(), Value::extension(k_i).into()); - - let fields = v1.as_sum().unwrap().1.collect::>(); - // we do not compare the value, just the key - assert_ne!(fields[0], v2); - assert_eq!(fields[0].value(), v2.value()); - } -} diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 0b73fcbb0..06781f7c5 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,7 +1,6 @@ //! Compilation passes acting on the HUGR program representation. pub mod const_fold; -pub mod const_fold2; pub mod dataflow; pub mod force_order; mod half_node; From 777694ca8db5568d6da13d191591e2b4806c6508 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:55:27 +0100 Subject: [PATCH 094/281] (TEMP) Rm total_context --- hugr-passes/src/dataflow.rs | 3 -- hugr-passes/src/dataflow/total_context.rs | 54 ----------------------- 2 files changed, 57 deletions(-) delete mode 100644 hugr-passes/src/dataflow/total_context.rs diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 6f437f882..6085c3e92 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -10,8 +10,5 @@ pub use machine::{Machine, TailLoopTermination}; mod partial_value; pub use partial_value::{BaseValue, PVEnum, PartialSum, PartialValue, Sum}; -mod total_context; -pub use total_context::TotalContext; - #[cfg(test)] mod test; diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs deleted file mode 100644 index d512912d0..000000000 --- a/hugr-passes/src/dataflow/total_context.rs +++ /dev/null @@ -1,54 +0,0 @@ -use std::hash::Hash; - -use ascent::lattice::BoundedLattice; -use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; - -use super::partial_value::{PartialValue, Sum}; -use super::{BaseValue, DFContext}; - -/// A simpler interface like [DFContext] but where the context only cares about -/// values that are completely known (as `V`s), i.e. not `Bottom`, `Top`, or -/// Sums of potentially multiple variants. -pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { - /// The representation of values on which [Self::interpret_leaf_op] operates - type InterpretableVal: From + TryFrom>; - /// Interpret a leaf op. - /// `ins` gives the input ports for which we know (interpretable) values, and will be non-empty. - /// Returns a list of output ports for which we know (abstract) values (may be empty). - fn interpret_leaf_op( - &self, - node: Node, - ins: &[(IncomingPort, Self::InterpretableVal)], - ) -> Vec<(OutgoingPort, V)>; -} - -impl> DFContext> for T { - fn interpret_leaf_op( - &self, - node: Node, - ins: &[PartialValue], - ) -> Option>> { - let op = self.get_optype(node); - let sig = op.dataflow_signature()?; - let known_ins = sig - .input_types() - .iter() - .enumerate() - .zip(ins.iter()) - .filter_map(|((i, ty), pv)| { - pv.clone() - .try_into_value::<>::InterpretableVal>(ty) - .ok() - .map(|v| (IncomingPort::from(i), v)) - }) - .collect::>(); - let known_outs = self.interpret_leaf_op(node, &known_ins); - (!known_outs.is_empty()).then(|| { - let mut res = vec![PartialValue::bottom(); sig.output_count()]; - for (p, v) in known_outs { - res[p.index()] = v.into(); - } - res - }) - } -} From 9e47b7fc4702555e8e0edc8c1e5550b96a0e61ea Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:56:48 +0100 Subject: [PATCH 095/281] Revert "(TEMP) Rm total_context" This reverts commit 777694ca8db5568d6da13d191591e2b4806c6508. --- hugr-passes/src/dataflow.rs | 3 ++ hugr-passes/src/dataflow/total_context.rs | 54 +++++++++++++++++++++++ 2 files changed, 57 insertions(+) create mode 100644 hugr-passes/src/dataflow/total_context.rs diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 6085c3e92..6f437f882 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -10,5 +10,8 @@ pub use machine::{Machine, TailLoopTermination}; mod partial_value; pub use partial_value::{BaseValue, PVEnum, PartialSum, PartialValue, Sum}; +mod total_context; +pub use total_context::TotalContext; + #[cfg(test)] mod test; diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs new file mode 100644 index 000000000..d512912d0 --- /dev/null +++ b/hugr-passes/src/dataflow/total_context.rs @@ -0,0 +1,54 @@ +use std::hash::Hash; + +use ascent::lattice::BoundedLattice; +use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; + +use super::partial_value::{PartialValue, Sum}; +use super::{BaseValue, DFContext}; + +/// A simpler interface like [DFContext] but where the context only cares about +/// values that are completely known (as `V`s), i.e. not `Bottom`, `Top`, or +/// Sums of potentially multiple variants. +pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { + /// The representation of values on which [Self::interpret_leaf_op] operates + type InterpretableVal: From + TryFrom>; + /// Interpret a leaf op. + /// `ins` gives the input ports for which we know (interpretable) values, and will be non-empty. + /// Returns a list of output ports for which we know (abstract) values (may be empty). + fn interpret_leaf_op( + &self, + node: Node, + ins: &[(IncomingPort, Self::InterpretableVal)], + ) -> Vec<(OutgoingPort, V)>; +} + +impl> DFContext> for T { + fn interpret_leaf_op( + &self, + node: Node, + ins: &[PartialValue], + ) -> Option>> { + let op = self.get_optype(node); + let sig = op.dataflow_signature()?; + let known_ins = sig + .input_types() + .iter() + .enumerate() + .zip(ins.iter()) + .filter_map(|((i, ty), pv)| { + pv.clone() + .try_into_value::<>::InterpretableVal>(ty) + .ok() + .map(|v| (IncomingPort::from(i), v)) + }) + .collect::>(); + let known_outs = self.interpret_leaf_op(node, &known_ins); + (!known_outs.is_empty()).then(|| { + let mut res = vec![PartialValue::bottom(); sig.output_count()]; + for (p, v) in known_outs { + res[p.index()] = v.into(); + } + res + }) + } +} From 668c030d4ace2812d9f68a0536c5b314bebe8a42 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:56:58 +0100 Subject: [PATCH 096/281] Revert "(TEMP) remove const_fold2 module" This reverts commit a96ab20ac7f24001c81ccbd846b94bbc86efc722. --- hugr-passes/src/const_fold2.rs | 8 + hugr-passes/src/const_fold2/context.rs | 101 +++++++++ hugr-passes/src/const_fold2/value_handle.rs | 222 ++++++++++++++++++++ hugr-passes/src/lib.rs | 1 + 4 files changed, 332 insertions(+) create mode 100644 hugr-passes/src/const_fold2.rs create mode 100644 hugr-passes/src/const_fold2/context.rs create mode 100644 hugr-passes/src/const_fold2/value_handle.rs diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs new file mode 100644 index 000000000..58f285d43 --- /dev/null +++ b/hugr-passes/src/const_fold2.rs @@ -0,0 +1,8 @@ +#![warn(missing_docs)] +//! An (example) use of the [super::dataflow](dataflow-analysis framework) +//! to perform constant-folding. + +// These are pub because this "example" is used for testing the framework. +mod context; +pub mod value_handle; +pub use context::HugrValueContext; diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs new file mode 100644 index 000000000..6338629df --- /dev/null +++ b/hugr-passes/src/const_fold2/context.rs @@ -0,0 +1,101 @@ +use std::hash::{Hash, Hasher}; +use std::ops::Deref; +use std::sync::Arc; + +use hugr_core::ops::{OpType, Value}; +use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; + +use super::value_handle::{ValueHandle, ValueKey}; +use crate::dataflow::TotalContext; + +/// A [context](crate::dataflow::DFContext) for doing analysis with [ValueHandle]s. +/// Interprets [LoadConstant](OpType::LoadConstant) nodes, +/// and [ExtensionOp](OpType::ExtensionOp) nodes where the extension does +/// (using [Value]s for extension-op inputs). +/// +/// Just stores a Hugr (actually any [HugrView]), +/// (there is )no state for operation-interpretation. +#[derive(Debug)] +pub struct HugrValueContext(Arc); + +impl HugrValueContext { + /// Creates a new instance, given ownership of the [HugrView] + pub fn new(hugr: H) -> Self { + Self(Arc::new(hugr)) + } +} + +// Deriving Clone requires H:HugrView to implement Clone, +// but we don't need that as we only clone the Arc. +impl Clone for HugrValueContext { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +// Any value used in an Ascent program must be hashable. +// However, there should only be one DFContext, so its hash is immaterial. +impl Hash for HugrValueContext { + fn hash(&self, _state: &mut I) {} +} + +impl PartialEq for HugrValueContext { + fn eq(&self, other: &Self) -> bool { + // Any AscentProgram should have only one DFContext (maybe cloned) + assert!(Arc::ptr_eq(&self.0, &other.0)); + true + } +} + +impl Eq for HugrValueContext {} + +impl PartialOrd for HugrValueContext { + fn partial_cmp(&self, other: &Self) -> Option { + // Any AscentProgram should have only one DFContext (maybe cloned) + assert!(Arc::ptr_eq(&self.0, &other.0)); + Some(std::cmp::Ordering::Equal) + } +} + +impl Deref for HugrValueContext { + type Target = Hugr; + + fn deref(&self) -> &Self::Target { + self.0.base_hugr() + } +} + +impl TotalContext for HugrValueContext { + type InterpretableVal = Value; + + fn interpret_leaf_op( + &self, + n: Node, + ins: &[(IncomingPort, Value)], + ) -> Vec<(OutgoingPort, ValueHandle)> { + match self.0.get_optype(n) { + OpType::LoadConstant(load_op) => { + assert!(ins.is_empty()); // static edge, so need to find constant + let const_node = self + .0 + .single_linked_output(n, load_op.constant_port()) + .unwrap() + .0; + let const_op = self.0.get_optype(const_node).as_const().unwrap(); + vec![( + OutgoingPort::from(0), + ValueHandle::new(const_node.into(), Arc::new(const_op.value().clone())), + )] + } + OpType::ExtensionOp(op) => { + let ins = ins.iter().map(|(p, v)| (*p, v.clone())).collect::>(); + op.constant_fold(&ins).map_or(Vec::new(), |outs| { + outs.into_iter() + .map(|(p, v)| (p, ValueHandle::new(ValueKey::Node(n), Arc::new(v)))) + .collect() + }) + } + _ => vec![], + } + } +} diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs new file mode 100644 index 000000000..59a08b50a --- /dev/null +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -0,0 +1,222 @@ +use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76. +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use hugr_core::ops::constant::{CustomConst, Sum}; +use hugr_core::ops::Value; +use hugr_core::types::Type; +use hugr_core::Node; + +use crate::dataflow::BaseValue; + +#[derive(Clone, Debug)] +pub struct HashedConst { + hash: u64, + val: Arc, +} + +impl PartialEq for HashedConst { + fn eq(&self, other: &Self) -> bool { + self.hash == other.hash && self.val.equal_consts(other.val.as_ref()) + } +} + +impl Eq for HashedConst {} + +impl Hash for HashedConst { + fn hash(&self, state: &mut H) { + state.write_u64(self.hash); + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum ValueKey { + Field(usize, Box), + Const(HashedConst), + Node(Node), +} + +impl From for ValueKey { + fn from(n: Node) -> Self { + Self::Node(n) + } +} + +impl From for ValueKey { + fn from(value: HashedConst) -> Self { + Self::Const(value) + } +} + +impl ValueKey { + pub fn new(n: Node, k: impl CustomConst) -> Self { + Self::try_new(k).unwrap_or(Self::Node(n)) + } + + pub fn try_new(cst: impl CustomConst) -> Option { + let mut hasher = DefaultHasher::new(); + cst.try_hash(&mut hasher).then(|| { + Self::Const(HashedConst { + hash: hasher.finish(), + val: Arc::new(cst), + }) + }) + } + + fn field(self, i: usize) -> Self { + Self::Field(i, Box::new(self)) + } +} + +#[derive(Clone, Debug)] +pub struct ValueHandle(ValueKey, Arc); + +impl ValueHandle { + pub fn new(key: ValueKey, value: Arc) -> Self { + Self(key, value) + } + + pub fn value(&self) -> &Value { + self.1.as_ref() + } + + pub fn get_type(&self) -> Type { + self.1.get_type() + } +} + +impl BaseValue for ValueHandle { + fn as_sum(&self) -> Option<(usize, impl Iterator + '_)> { + match self.value() { + Value::Sum(Sum { tag, values, .. }) => Some(( + *tag, + values + .iter() + .enumerate() + .map(|(i, v)| Self(self.0.clone().field(i), Arc::new(v.clone()))), + )), + _ => None, + } + } +} + +impl PartialEq for ValueHandle { + fn eq(&self, other: &Self) -> bool { + // If the keys are equal, we return true since the values must have the + // same provenance, and so be equal. If the keys are different but the + // values are equal, we could return true if we didn't impl Eq, but + // since we do impl Eq, the Hash contract prohibits us from having equal + // values with different hashes. + let r = self.0 == other.0; + if r { + debug_assert_eq!(self.get_type(), other.get_type()); + } + r + } +} + +impl Eq for ValueHandle {} + +impl Hash for ValueHandle { + fn hash(&self, state: &mut I) { + self.0.hash(state); + } +} + +impl From for Value { + fn from(value: ValueHandle) -> Self { + (*value.1).clone() + } +} + +#[cfg(test)] +mod test { + use hugr_core::{ + extension::prelude::ConstString, + ops::constant::CustomConst as _, + std_extensions::{ + arithmetic::{ + float_types::{ConstF64, FLOAT64_TYPE}, + int_types::{ConstInt, INT_TYPES}, + }, + collections::ListValue, + }, + types::SumType, + }; + + use super::*; + + #[test] + fn value_key_eq() { + let n = Node::from(portgraph::NodeIndex::new(0)); + let n2: Node = portgraph::NodeIndex::new(1).into(); + let k1 = ValueKey::new(n, ConstString::new("foo".to_string())); + let k2 = ValueKey::new(n2, ConstString::new("foo".to_string())); + let k3 = ValueKey::new(n, ConstString::new("bar".to_string())); + + assert_eq!(k1, k2); // Node ignored + assert_ne!(k1, k3); + + assert_eq!(ValueKey::from(n), ValueKey::from(n)); + let f = ConstF64::new(std::f64::consts::PI); + assert_eq!(ValueKey::new(n, f.clone()), ValueKey::from(n)); + + assert_ne!(ValueKey::new(n, f.clone()), ValueKey::new(n2, f)); // Node taken into account + let k4 = ValueKey::from(n); + let k5 = ValueKey::from(n); + let k6: ValueKey = ValueKey::from(n2); + + assert_eq!(&k4, &k5); + assert_ne!(&k4, &k6); + + let k7 = k5.clone().field(3); + let k4 = k4.field(3); + + assert_eq!(&k4, &k7); + + let k5 = k5.field(2); + + assert_ne!(&k5, &k7); + } + + #[test] + fn value_key_list() { + let v1 = ConstInt::new_u(3, 3).unwrap(); + let v2 = ConstInt::new_u(4, 3).unwrap(); + let v3 = ConstF64::new(std::f64::consts::PI); + + let n = Node::from(portgraph::NodeIndex::new(0)); + let n2: Node = portgraph::NodeIndex::new(1).into(); + + let lst = ListValue::new(INT_TYPES[0].clone(), [v1.into(), v2.into()]); + assert_eq!(ValueKey::new(n, lst.clone()), ValueKey::new(n2, lst)); + + let lst = ListValue::new(FLOAT64_TYPE, [v3.into()]); + assert_ne!( + ValueKey::new(n, lst.clone()), + ValueKey::new(n2, lst.clone()) + ); + } + + #[test] + fn value_handle_eq() { + let k_i = ConstInt::new_u(4, 2).unwrap(); + let subject_val = Arc::new( + Value::sum( + 0, + [k_i.clone().into()], + SumType::new([vec![k_i.get_type()], vec![]]), + ) + .unwrap(), + ); + + let k1 = ValueKey::try_new(ConstString::new("foo".to_string())).unwrap(); + let v1 = ValueHandle::new(k1.clone(), subject_val.clone()); + let v2 = ValueHandle::new(k1.clone(), Value::extension(k_i).into()); + + let fields = v1.as_sum().unwrap().1.collect::>(); + // we do not compare the value, just the key + assert_ne!(fields[0], v2); + assert_eq!(fields[0].value(), v2.value()); + } +} diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 06781f7c5..0b73fcbb0 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,6 +1,7 @@ //! Compilation passes acting on the HUGR program representation. pub mod const_fold; +pub mod const_fold2; pub mod dataflow; pub mod force_order; mod half_node; From 5a16e6bd740cff1d55fd598295b725629187745a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:59:30 +0100 Subject: [PATCH 097/281] clippy --- hugr-passes/src/dataflow/partial_value.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index dcce791db..de1cd55d9 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -618,7 +618,7 @@ mod test { } fn partial_sum_strat( - variants: &Vec>>, + variants: &[Vec>], ) -> impl Strategy> { // We have to clone the `variants` here but only as far as the Vec>> let tagged_variants = variants.iter().cloned().enumerate().collect::>(); From 4f311782a4482639cd7afc451174bed8842c847d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 10:11:10 +0100 Subject: [PATCH 098/281] Better fix for PartialSum::try_meet_mut --- hugr-passes/src/dataflow/partial_value.rs | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index de1cd55d9..0fe37af43 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -101,8 +101,11 @@ impl PartialSum { /// Mutates self according to lattice meet operation (towards `Bottom`). If successful, /// returns whether `self` has changed. /// - /// Fails (without mutation) with the conflicting tag if any common rows have different lengths - pub fn try_meet_mut(&mut self, other: Self) -> Result { + /// # Errors + /// Fails without mutation, either: + /// * `Some(tag)` if the two [PartialSum]s both had rows with that `tag` but of different lengths + /// * `None` if the two instances had no rows in common (i.e., the result is "Bottom") + pub fn try_meet_mut(&mut self, other: Self) -> Result> { let mut changed = false; let mut keys_to_remove = vec![]; for (k, v) in self.0.iter() { @@ -110,11 +113,14 @@ impl PartialSum { None => keys_to_remove.push(*k), Some(o_v) => { if v.len() != o_v.len() { - return Err(*k); + return Err(Some(*k)); } } } } + if keys_to_remove.len() == self.0.len() { + return Err(None); + } for (k, v) in other.0 { if let Some(row) = self.0.get_mut(&k) { for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { @@ -451,16 +457,7 @@ impl Lattice for PartialValue { _ => unreachable!(), }; match ps1.try_meet_mut(ps2) { - Ok(ch) => { - // ALAN the 'invariant' that a PartialSum always has >=1 tag can be broken here. - // Fix this by rewriting to Bottom, but should probably be refactored - at the - // least, it seems dangerous to expose a potentially-invalidating try_meet_mut. - if ps1.0.is_empty() { - assert!(ch); - self.0 = PVEnum::Bottom - } - ch - } + Ok(ch) => ch, Err(_) => { self.0 = PVEnum::Bottom; true From 2b523c90f590b35dd691075fa651416504926b6d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 10:19:55 +0100 Subject: [PATCH 099/281] true_or_false uses pv_true+pv_false --- hugr-passes/src/dataflow/test.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 4f42b4a4e..cb131f8d4 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -235,7 +235,7 @@ fn test_tail_loop_iterates_twice() { machine.propolutate_out_wires([(true_w, pv_true()), (false_w, pv_false())]); machine.run(TestContext(Arc::new(&hugr))); - let true_or_false = PartialValue::new_variant(0, []).join(PartialValue::new_variant(1, [])); + let true_or_false = pv_true().join(pv_false()); // TODO these should be the propagated values for now they will be join(true,false) - ALAN wtf? let o_r1 = machine.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, true_or_false); From 1d2cb9bad3a8b9385fdf5bbce383ea6dfb0063c6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 12:24:24 +0100 Subject: [PATCH 100/281] Update to ascent 0.7.0, drop fn join/meet as these are now trait-default --- hugr-passes/Cargo.toml | 2 +- hugr-passes/src/dataflow/datalog.rs | 10 ---------- hugr-passes/src/dataflow/partial_value.rs | 10 ---------- 3 files changed, 1 insertion(+), 21 deletions(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index ff234494e..77b185d31 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -16,7 +16,7 @@ categories = ["compilers"] hugr-core = { path = "../hugr-core", version = "0.10.0" } portgraph = { workspace = true } # This ascent commit has a fix for unsoundness in release/tag 0.6.0: -ascent = {git = "https://github.com/s-arash/ascent", rev="9805d02cb830b6e66abcd4d48836a14cd98366f3"} +ascent = { version = "0.7.0" } downcast-rs = { workspace = true } itertools = { workspace = true } lazy_static = { workspace = true } diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index fcde4f96b..3d23ca269 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -264,16 +264,6 @@ impl PartialOrd for ValueRow { } impl Lattice for ValueRow { - fn meet(mut self, other: Self) -> Self { - self.meet_mut(other); - self - } - - fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self - } - fn join_mut(&mut self, other: Self) -> bool { assert_eq!(self.0.len(), other.0.len()); let mut changed = false; diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 0fe37af43..2e67d0bb2 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -378,11 +378,6 @@ where } impl Lattice for PartialValue { - fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self - } - fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); // println!("join {self:?}\n{:?}", &other); @@ -425,11 +420,6 @@ impl Lattice for PartialValue { } } - fn meet(mut self, other: Self) -> Self { - self.meet_mut(other); - self - } - fn meet_mut(&mut self, other: Self) -> bool { self.assert_invariants(); match (&self.0, other.0) { From ee91bbeae480e4d6d4b7d1f003d06cb43c357c7e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 12:26:56 +0100 Subject: [PATCH 101/281] Cargo.toml: oops, remove obsolete comment --- hugr-passes/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 77b185d31..cdf782ff3 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -15,7 +15,6 @@ categories = ["compilers"] [dependencies] hugr-core = { path = "../hugr-core", version = "0.10.0" } portgraph = { workspace = true } -# This ascent commit has a fix for unsoundness in release/tag 0.6.0: ascent = { version = "0.7.0" } downcast-rs = { workspace = true } itertools = { workspace = true } From e67051ff6daca7f3958a9af51bf33bba3a8ce7a0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 14:59:19 +0100 Subject: [PATCH 102/281] ValueRow cleanups (remove misleading 'pub's) --- hugr-passes/src/dataflow/datalog.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 3d23ca269..d4d040af5 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -221,22 +221,22 @@ fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator(Vec); impl ValueRow { - pub fn new(len: usize) -> Self { + fn new(len: usize) -> Self { Self(vec![PV::bottom(); len]) } - pub fn single_known(len: usize, idx: usize, v: PV) -> Self { + fn single_known(len: usize, idx: usize, v: PV) -> Self { assert!(idx < len); let mut r = Self::new(len); r.0[idx] = v; r } - pub fn iter(&self) -> impl Iterator { + fn iter(&self) -> impl Iterator { self.0.iter() } - pub fn unpack_first( + fn unpack_first( &self, variant: usize, len: usize, @@ -245,10 +245,6 @@ impl ValueRow { .variant_values(variant, len) .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) } - - // fn initialised(&self) -> bool { - // self.0.iter().all(|x| x != &PV::top()) - // } } impl FromIterator for ValueRow { From 94cee551202af0ba67941819ae0a7582baf12597 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 15:00:28 +0100 Subject: [PATCH 103/281] Refactor: rm tail_node, clone earlier in ValueRow::unpack_first, rm ValueRow::iter Only one use of tail_node that didn't *also* filter to TailLoop nodes itself. Total amount of copying in unpack_first still same, but memory usage increased (clones whole lot in one go). Was required to fix borrow issue with refactor, but same issue otherwise prevents the next commit (CFGs)... --- hugr-passes/src/dataflow/datalog.rs | 30 ++++++++++++----------------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index d4d040af5..5217ab276 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -114,27 +114,28 @@ ascent::ascent! { // TailLoop - relation tail_loop_node(C, Node); - tail_loop_node(c,n) <-- node(c, n), if c.get_optype(*n).is_tail_loop(); // inputs of tail loop propagate to Input node of child region - out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- tail_loop_node(c, tl), - io_node(c,tl,i, IO::Input), in_wire_value(c, tl, p, v); + out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- node(c, tl), + if c.get_optype(*tl).is_tail_loop(), + io_node(c,tl,i, IO::Input), + in_wire_value(c, tl, p, v); // Output node of child region propagate to Input node of child region - out_wire_value(c, in_n, out_p, v) <-- tail_loop_node(c, tl_n), + out_wire_value(c, in_n, out_p, v) <-- node(c, tl_n), + if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), io_node(c,tl_n,in_n, IO::Input), io_node(c,tl_n,out_n, IO::Output), node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node - if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), + if let Some(fields) = out_in_row.unpack_first(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); // Output node of child region propagate to outputs of tail loop - out_wire_value(c, tl_n, out_p, v) <-- tail_loop_node(c, tl_n), + out_wire_value(c, tl_n, out_p, v) <-- node(c, tl_n), + if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), io_node(c,tl_n,out_n, IO::Output), node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node - if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); @@ -232,18 +233,11 @@ impl ValueRow { r } - fn iter(&self) -> impl Iterator { - self.0.iter() - } - - fn unpack_first( - &self, - variant: usize, - len: usize, - ) -> Option + '_> { + fn unpack_first(&self, variant: usize, len: usize) -> Option> { + let rest: Vec<_> = self.0[1..].to_owned(); self[0] .variant_values(variant, len) - .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) + .map(|vals| vals.into_iter().chain(rest)) } } From 5cf5ff0aeb87a5d0aba63dc1fa28ac60b29ed8b7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 15:12:56 +0100 Subject: [PATCH 104/281] Add datalog for CFG --- hugr-passes/src/dataflow/datalog.rs | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 5217ab276..69989d9b0 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -169,6 +169,34 @@ ascent::ascent! { in_wire_value(c, cond, IncomingPort::from(0), v), let reachable = v.supports_tag(*i); + // CFG + relation cfg_node(C, Node); + relation dfb_block(C, Node, Node); + cfg_node(c,n) <-- node(c, n), if c.get_optype(*n).is_cfg(); + dfb_block(c,cfg,blk) <-- cfg_node(c, cfg), for blk in c.children(*cfg), if c.get_optype(blk).is_dataflow_block(); + + // Where do the values "fed" along a control-flow edge come out? + relation _cfg_succ_dest(C, Node, Node, Node); + _cfg_succ_dest(c, cfg, blk, inp) <-- dfb_block(c, cfg, blk), io_node(c, blk, inp, IO::Input); + _cfg_succ_dest(c, cfg, exit, cfg) <-- cfg_node(c, cfg), if let Some(exit) = c.children(*cfg).skip(1).next(); + + // Inputs of CFG propagate to entry block + out_wire_value(c, i_node, OutgoingPort::from(p.index()), v) <-- + cfg_node(c, cfg), + if let Some(entry) = c.children(*cfg).next(), + io_node(c, entry, i_node, IO::Input), + in_wire_value(c, cfg, p, v); + + // Outputs of each block propagated to successor blocks or (if exit block) then CFG itself + out_wire_value(c, dest, out_p, v) <-- + dfb_block(c, cfg, pred), + if let Some(df_block) = c.get_optype(*pred).as_dataflow_block(), + for (succ_n, succ) in c.output_neighbours(*pred).enumerate(), + io_node(c, pred, out_n, IO::Output), + _cfg_succ_dest(c, cfg, succ, dest), + node_in_value_row(c, out_n, out_in_row), + if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), + for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); } fn propagate_leaf_op( From 7381087bf4a04a66be104a2eed32f8a2ae7573d1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 15:45:32 +0100 Subject: [PATCH 105/281] refactor: follow unpack_first with enumerate --- hugr-passes/src/dataflow/datalog.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 69989d9b0..750822f14 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -122,22 +122,22 @@ ascent::ascent! { in_wire_value(c, tl, p, v); // Output node of child region propagate to Input node of child region - out_wire_value(c, in_n, out_p, v) <-- node(c, tl_n), + out_wire_value(c, in_n, OutgoingPort::from(out_p), v) <-- node(c, tl_n), if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), io_node(c,tl_n,in_n, IO::Input), io_node(c,tl_n,out_n, IO::Output), node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node if let Some(fields) = out_in_row.unpack_first(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 - for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); + for (out_p, v) in fields.enumerate(); // Output node of child region propagate to outputs of tail loop - out_wire_value(c, tl_n, out_p, v) <-- node(c, tl_n), + out_wire_value(c, tl_n, OutgoingPort::from(out_p), v) <-- node(c, tl_n), if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), io_node(c,tl_n,out_n, IO::Output), node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 - for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); + for (out_p, v) in fields.enumerate(); // Conditional relation conditional_node(C, Node); @@ -149,14 +149,14 @@ ascent::ascent! { if c.get_optype(case).is_case(); // inputs of conditional propagate into case nodes - out_wire_value(c, i_node, i_p, v) <-- + out_wire_value(c, i_node, OutgoingPort::from(out_p), v) <-- case_node(c, cond, case_index, case), io_node(c, case, i_node, IO::Input), node_in_value_row(c, cond, in_row), //in_wire_value(c, cond, cond_in_p, cond_in_v), if let Some(conditional) = c.get_optype(*cond).as_conditional(), if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), - for (i_p, v) in (0..).map(OutgoingPort::from).zip(fields); + for (out_p, v) in fields.enumerate(); // outputs of case nodes propagate to outputs of conditional out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- @@ -188,7 +188,7 @@ ascent::ascent! { in_wire_value(c, cfg, p, v); // Outputs of each block propagated to successor blocks or (if exit block) then CFG itself - out_wire_value(c, dest, out_p, v) <-- + out_wire_value(c, dest, OutgoingPort::from(out_p), v) <-- dfb_block(c, cfg, pred), if let Some(df_block) = c.get_optype(*pred).as_dataflow_block(), for (succ_n, succ) in c.output_neighbours(*pred).enumerate(), @@ -196,7 +196,7 @@ ascent::ascent! { _cfg_succ_dest(c, cfg, succ, dest), node_in_value_row(c, out_n, out_in_row), if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), - for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); + for (out_p, v) in fields.enumerate(); } fn propagate_leaf_op( From 60e33dba2473535965de85533bc46d9f8ea01de0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 17:01:04 +0100 Subject: [PATCH 106/281] Remove comments from test_tail_loop_(iterates_twice->two_iters) --- hugr-passes/src/dataflow/test.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index cb131f8d4..0babf2945 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -202,20 +202,17 @@ fn test_tail_loop_always_iterates() { } #[test] -fn test_tail_loop_iterates_twice() { +fn test_tail_loop_two_iters() { let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); - // let var_type = Type::new_sum([type_row![BOOL_T,BOOL_T], type_row![BOOL_T,BOOL_T]]); let true_w = builder.add_load_value(Value::true_val()); let false_w = builder.add_load_value(Value::false_val()); - // let r_w = builder - // .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); let tlb = builder .tail_loop_builder_exts( [], [(BOOL_T, false_w), (BOOL_T, true_w)], - vec![].into(), + type_row![], ExtensionSet::new(), ) .unwrap(); @@ -227,8 +224,6 @@ fn test_tail_loop_iterates_twice() { let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - // TODO once we can do conditionals put these wires inside `just_outputs` and - // we should be able to propagate their values...ALAN wtf? loop control type IS bool ATM let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); @@ -236,7 +231,6 @@ fn test_tail_loop_iterates_twice() { machine.run(TestContext(Arc::new(&hugr))); let true_or_false = pv_true().join(pv_false()); - // TODO these should be the propagated values for now they will be join(true,false) - ALAN wtf? let o_r1 = machine.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, true_or_false); let o_r2 = machine.read_out_wire(o_w2).unwrap(); From ef4f4335d24c22ad8462fe082fe2af1c299e656f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 17:01:15 +0100 Subject: [PATCH 107/281] Add a test of tail loop around conditional --- hugr-passes/src/dataflow/test.rs | 61 ++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 0babf2945..21fa80547 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -242,6 +242,67 @@ fn test_tail_loop_two_iters() { assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); } +#[test] +fn test_tail_loop_containing_conditional() { + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + let body_out_variants = vec![type_row![BOOL_T; 2]; 2]; + + let true_w = builder.add_load_value(Value::true_val()); + let false_w = builder.add_load_value(Value::false_val()); + + let mut tlb = builder + .tail_loop_builder_exts( + [(BOOL_T, false_w), (BOOL_T, true_w)], + [], + type_row![BOOL_T, BOOL_T], + ExtensionSet::new(), + ) + .unwrap(); + assert_eq!( + tlb.loop_signature().unwrap().signature(), + Signature::new_endo(type_row![BOOL_T, BOOL_T]) + ); + let [in_w1, in_w2] = tlb.input_wires_arr(); + + // Branch on in_w1, so first iter (false, true) uses false == tag 0 == continue with (true, true) + // second iter (true, true) uses true == tag 1 == break with (true, true) + let mut cond = tlb + .conditional_builder( + (vec![type_row![]; 2], in_w1), + [], + Type::new_sum(body_out_variants.clone()).into(), + ) + .unwrap(); + for (tag, second_output) in [(0, true_w), (1, false_w)] { + let mut case_b = cond.case_builder(tag).unwrap(); + let r = case_b + .add_dataflow_op(Tag::new(tag, body_out_variants), [in_w2, second_output]) + .unwrap() + .outputs(); + case_b.finish_with_outputs(r).unwrap(); + } + let [r] = cond.finish_sub_container().unwrap().outputs_arr(); + + let tail_loop = tlb.finish_with_outputs(r, []).unwrap(); + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + let [o_w1, o_w2] = tail_loop.outputs_arr(); + + let mut machine = Machine::default(); + machine.propolutate_out_wires([(true_w, pv_true()), (false_w, pv_false())]); + machine.run(TestContext(Arc::new(&hugr))); + + let o_r1 = machine.read_out_wire(o_w1).unwrap(); + assert_eq!(o_r1, pv_true()); + let o_r2 = machine.read_out_wire(o_w2).unwrap(); + assert_eq!(o_r2, pv_false()); + assert_eq!( + Some(TailLoopTermination::BreaksAndContinues), + machine.tail_loop_terminates(&hugr, tail_loop.node()) + ); + assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); +} + #[test] fn conditional() { let variants = vec![type_row![], type_row![], type_row![BOOL_T]]; From 59354894d06b309e5cd392c22bac0575e5a15fef Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 17:56:19 +0100 Subject: [PATCH 108/281] improve that test - loop input is a sum and the variants have different values --- hugr-passes/src/dataflow/test.rs | 62 ++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 27 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 21fa80547..125838010 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -2,6 +2,7 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; + use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -245,42 +246,49 @@ fn test_tail_loop_two_iters() { #[test] fn test_tail_loop_containing_conditional() { let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); - let body_out_variants = vec![type_row![BOOL_T; 2]; 2]; - - let true_w = builder.add_load_value(Value::true_val()); - let false_w = builder.add_load_value(Value::false_val()); + let control_variants = vec![type_row![BOOL_T;2]; 2]; + let control_t = Type::new_sum(control_variants.clone()); + let body_out_variants = vec![control_t.clone().into(), type_row![BOOL_T; 2]]; + + let init = builder.add_load_value( + Value::sum( + 0, + [Value::false_val(), Value::true_val()], + SumType::new(control_variants.clone()), + ) + .unwrap(), + ); let mut tlb = builder - .tail_loop_builder_exts( - [(BOOL_T, false_w), (BOOL_T, true_w)], - [], - type_row![BOOL_T, BOOL_T], - ExtensionSet::new(), - ) + .tail_loop_builder([(control_t, init)], [], type_row![BOOL_T; 2]) .unwrap(); - assert_eq!( - tlb.loop_signature().unwrap().signature(), - Signature::new_endo(type_row![BOOL_T, BOOL_T]) - ); - let [in_w1, in_w2] = tlb.input_wires_arr(); + let [in_w] = tlb.input_wires_arr(); - // Branch on in_w1, so first iter (false, true) uses false == tag 0 == continue with (true, true) - // second iter (true, true) uses true == tag 1 == break with (true, true) + // Branch on in_wire, so first iter 0(false, true)... let mut cond = tlb .conditional_builder( - (vec![type_row![]; 2], in_w1), + (control_variants.clone(), in_w), [], Type::new_sum(body_out_variants.clone()).into(), ) .unwrap(); - for (tag, second_output) in [(0, true_w), (1, false_w)] { - let mut case_b = cond.case_builder(tag).unwrap(); - let r = case_b - .add_dataflow_op(Tag::new(tag, body_out_variants), [in_w2, second_output]) - .unwrap() - .outputs(); - case_b.finish_with_outputs(r).unwrap(); - } + let mut case0_b = cond.case_builder(0).unwrap(); + let [a, b] = case0_b.input_wires_arr(); + // Builds value for next iter as 1(true, false) by flipping arguments + let [next_input] = case0_b + .add_dataflow_op(Tag::new(1, control_variants), [b, a]) + .unwrap() + .outputs_arr(); + let cont = case0_b + .add_dataflow_op(Tag::new(0, body_out_variants.clone()), [next_input]) + .unwrap(); + case0_b.finish_with_outputs(cont.outputs()).unwrap(); + // Second iter 1(true, false) => exit with (true, false) + let mut case1_b = cond.case_builder(1).unwrap(); + let loop_res = case1_b + .add_dataflow_op(Tag::new(1, body_out_variants), case1_b.input_wires()) + .unwrap(); + case1_b.finish_with_outputs(loop_res.outputs()).unwrap(); let [r] = cond.finish_sub_container().unwrap().outputs_arr(); let tail_loop = tlb.finish_with_outputs(r, []).unwrap(); @@ -289,7 +297,7 @@ fn test_tail_loop_containing_conditional() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(true_w, pv_true()), (false_w, pv_false())]); + machine.propolutate_out_wires([(init, PartialValue::new_variant(0, [pv_false(), pv_true()]))]); machine.run(TestContext(Arc::new(&hugr))); let o_r1 = machine.read_out_wire(o_w1).unwrap(); From b19868133a3ed3e5dfabea55404cbcd8f5cb4216 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Oct 2024 10:00:26 +0100 Subject: [PATCH 109/281] clippy/nth --- hugr-core/src/types.rs | 2 -- hugr-passes/src/dataflow/datalog.rs | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 5afab2294..39980d65f 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -16,9 +16,7 @@ use crate::types::type_param::check_type_arg; use crate::utils::display_list_with_separator; pub use check::SumTypeError; pub use custom::CustomType; -pub(crate) use poly_func::PolyFuncTypeBase; pub use poly_func::{PolyFuncType, PolyFuncTypeRV}; -pub(crate) use signature::FuncTypeBase; pub use signature::{FuncValueType, Signature}; use smol_str::SmolStr; pub use type_param::TypeArg; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 750822f14..3b6fb05a4 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -178,7 +178,7 @@ ascent::ascent! { // Where do the values "fed" along a control-flow edge come out? relation _cfg_succ_dest(C, Node, Node, Node); _cfg_succ_dest(c, cfg, blk, inp) <-- dfb_block(c, cfg, blk), io_node(c, blk, inp, IO::Input); - _cfg_succ_dest(c, cfg, exit, cfg) <-- cfg_node(c, cfg), if let Some(exit) = c.children(*cfg).skip(1).next(); + _cfg_succ_dest(c, cfg, exit, cfg) <-- cfg_node(c, cfg), if let Some(exit) = c.children(*cfg).nth(1); // Inputs of CFG propagate to entry block out_wire_value(c, i_node, OutgoingPort::from(p.index()), v) <-- From ed30f808a51d503cc58996998de34ef04b8790d8 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Oct 2024 12:22:14 +0100 Subject: [PATCH 110/281] revert accidental changes to hugr-core/src/types.rs (how?!) --- hugr-core/src/types.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 39980d65f..5afab2294 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -16,7 +16,9 @@ use crate::types::type_param::check_type_arg; use crate::utils::display_list_with_separator; pub use check::SumTypeError; pub use custom::CustomType; +pub(crate) use poly_func::PolyFuncTypeBase; pub use poly_func::{PolyFuncType, PolyFuncTypeRV}; +pub(crate) use signature::FuncTypeBase; pub use signature::{FuncValueType, Signature}; use smol_str::SmolStr; pub use type_param::TypeArg; From 3f7808ade8e149d6d253037f94d945d88bdceede Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Oct 2024 15:18:03 +0100 Subject: [PATCH 111/281] Cleanup conditional, cfg, unpack_first --- hugr-passes/src/dataflow/datalog.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 3b6fb05a4..b64de3aaf 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -153,8 +153,7 @@ ascent::ascent! { case_node(c, cond, case_index, case), io_node(c, case, i_node, IO::Input), node_in_value_row(c, cond, in_row), - //in_wire_value(c, cond, cond_in_p, cond_in_v), - if let Some(conditional) = c.get_optype(*cond).as_conditional(), + let conditional = c.get_optype(*cond).as_conditional().unwrap(), if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), for (out_p, v) in fields.enumerate(); @@ -190,7 +189,7 @@ ascent::ascent! { // Outputs of each block propagated to successor blocks or (if exit block) then CFG itself out_wire_value(c, dest, OutgoingPort::from(out_p), v) <-- dfb_block(c, cfg, pred), - if let Some(df_block) = c.get_optype(*pred).as_dataflow_block(), + let df_block = c.get_optype(*pred).as_dataflow_block().unwrap(), for (succ_n, succ) in c.output_neighbours(*pred).enumerate(), io_node(c, pred, out_n, IO::Output), _cfg_succ_dest(c, cfg, succ, dest), @@ -262,10 +261,8 @@ impl ValueRow { } fn unpack_first(&self, variant: usize, len: usize) -> Option> { - let rest: Vec<_> = self.0[1..].to_owned(); - self[0] - .variant_values(variant, len) - .map(|vals| vals.into_iter().chain(rest)) + let vals = self[0].variant_values(variant, len)?; + Some(vals.into_iter().chain(self.0[1..].to_owned())) } } From 436b63533d908ea9215073ce0c668aac14a3ad0d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 19:39:18 +0100 Subject: [PATCH 112/281] Complex CFG that does a not-XOR...but analysis generally says "true or false" --- hugr-passes/src/dataflow/test.rs | 122 ++++++++++++++++++++++++++++++- 1 file changed, 119 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 125838010..e3e8545ec 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -3,6 +3,9 @@ use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; +use hugr_core::builder::CFGBuilder; +use hugr_core::types::TypeRow; +use hugr_core::Wire; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -14,6 +17,7 @@ use hugr_core::{ types::{Signature, SumType, Type}, HugrView, }; +use rstest::rstest; use super::{AbstractValue, BaseValue, DFContext, Machine, PartialValue, TailLoopTermination}; @@ -90,6 +94,10 @@ fn pv_true() -> PartialValue { PartialValue::new_variant(1, []) } +fn pv_true_or_false() -> PartialValue { + pv_true().join(pv_false()) +} + #[test] fn test_make_tuple() { let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); @@ -231,11 +239,10 @@ fn test_tail_loop_two_iters() { machine.propolutate_out_wires([(true_w, pv_true()), (false_w, pv_false())]); machine.run(TestContext(Arc::new(&hugr))); - let true_or_false = pv_true().join(pv_false()); let o_r1 = machine.read_out_wire(o_w1).unwrap(); - assert_eq!(o_r1, true_or_false); + assert_eq!(o_r1, pv_true_or_false()); let o_r2 = machine.read_out_wire(o_w2).unwrap(); - assert_eq!(o_r2, true_or_false); + assert_eq!(o_r2, pv_true_or_false()); assert_eq!( Some(TailLoopTermination::BreaksAndContinues), machine.tail_loop_terminates(&hugr, tail_loop.node()) @@ -371,3 +378,112 @@ fn conditional() { assert_eq!(machine.case_reachable(&hugr, case3.node()), Some(true)); assert_eq!(machine.case_reachable(&hugr, cond.node()), None); } + +#[rstest] +#[case(pv_true(), pv_true(), pv_true())] // OK +#[case(pv_true(), pv_false(), pv_true_or_false())] // Result should be false ?? +#[case(pv_false(), pv_true(), pv_true_or_false())] // Result should be false ?? +#[case(pv_false(), pv_false(), pv_true_or_false())] // Result should be true?? +#[case(PartialValue::top(), pv_true(), PartialValue::top())] // Result should be true_or_false? TOP means all inputs inside cases are TOP +#[case(PartialValue::top(), pv_false(), PartialValue::top())] // Result should be true_or_false? +fn cfg( + #[case] inp0: PartialValue, + #[case] inp1: PartialValue, + #[case] outp: PartialValue, +) { + // Entry + // /0 1\ + // A --1-> B + // \0 / + // > X < + let mut builder = CFGBuilder::new(Signature::new(type_row![BOOL_T;2], BOOL_T)).unwrap(); + + // entry (i, j) => if i {B(j)} else {A(j, i, true)}, note that (j, i, true) == (j, false, true) + let entry_outs = [type_row![BOOL_T;3], type_row![BOOL_T]]; + let mut entry = builder + .entry_builder(entry_outs.clone(), type_row![]) + .unwrap(); + let [in_i, in_j] = entry.input_wires_arr(); + let mut cond = entry + .conditional_builder( + (vec![type_row![]; 2], in_i), + [], + Type::new_sum(entry_outs.clone()).into(), + ) + .unwrap(); + let mut if_i_true = cond.case_builder(1).unwrap(); + let br_to_b = if_i_true + .add_dataflow_op(Tag::new(1, entry_outs.to_vec()), [in_j]) + .unwrap(); + if_i_true.finish_with_outputs(br_to_b.outputs()).unwrap(); + let mut if_i_false = cond.case_builder(0).unwrap(); + let true_w = if_i_false.add_load_value(Value::true_val()); + let br_to_a = if_i_false + .add_dataflow_op(Tag::new(0, entry_outs.into()), [in_j, in_i, true_w]) + .unwrap(); + if_i_false.finish_with_outputs(br_to_a.outputs()).unwrap(); + + let [res] = cond.finish_sub_container().unwrap().outputs_arr(); + let entry = entry.finish_with_outputs(res, []).unwrap(); + + // A(w, y, z) => if w {B(y)} else {X(z)} + let a_outs = vec![type_row![BOOL_T]; 2]; + let mut a = builder + .block_builder( + type_row![BOOL_T; 3], + vec![type_row![BOOL_T]; 2], + type_row![], + ) + .unwrap(); + let [in_w, in_y, in_z] = a.input_wires_arr(); + let mut cond = a + .conditional_builder( + (vec![type_row![]; 2], in_w), + [], + Type::new_sum(a_outs.clone()).into(), + ) + .unwrap(); + let mut if_w_true = cond.case_builder(1).unwrap(); + let br_to_b = if_w_true + .add_dataflow_op(Tag::new(1, a_outs.clone()), [in_y]) + .unwrap(); + if_w_true.finish_with_outputs(br_to_b.outputs()).unwrap(); + let mut if_w_false = cond.case_builder(0).unwrap(); + let br_to_x = if_w_false + .add_dataflow_op(Tag::new(0, a_outs), [in_z]) + .unwrap(); + if_w_false.finish_with_outputs(br_to_x.outputs()).unwrap(); + let [res] = cond.finish_sub_container().unwrap().outputs_arr(); + let a = a.finish_with_outputs(res, []).unwrap(); + + // B(v) => X(v) + let mut b = builder + .block_builder(type_row![BOOL_T], [type_row![BOOL_T]], type_row![]) + .unwrap(); + let [control] = b + .add_dataflow_op(Tag::new(0, vec![type_row![BOOL_T]]), b.input_wires()) + .unwrap() + .outputs_arr(); + let b = b.finish_with_outputs(control, []).unwrap(); + + let x = builder.exit_block(); + + builder.branch(&entry, 0, &a).unwrap(); + builder.branch(&entry, 1, &b).unwrap(); + builder.branch(&a, 0, &x).unwrap(); + builder.branch(&a, 1, &b).unwrap(); + builder.branch(&b, 0, &x).unwrap(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let [entry_input, _] = hugr.get_io(entry.node()).unwrap(); + let [in_w0, in_w1] = [0, 1].map(|i| Wire::new(entry_input, i)); + + let mut machine = Machine::default(); + machine.propolutate_out_wires([(in_w0, inp0), (in_w1, inp1), (true_w, pv_true())]); + machine.run(TestContext(Arc::new(&hugr))); + + assert_eq!( + machine.read_out_wire(Wire::new(hugr.root(), 0)).unwrap(), + outp + ); +} From 0374d130c87a41575848ba7a5661daf61974abfc Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Oct 2024 16:13:25 +0100 Subject: [PATCH 113/281] Propagate case results to conditional output only if case reached; some test fix --- hugr-passes/src/dataflow/datalog.rs | 6 ++++-- hugr-passes/src/dataflow/test.rs | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index b64de3aaf..258d76c0d 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -157,9 +157,11 @@ ascent::ascent! { if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), for (out_p, v) in fields.enumerate(); - // outputs of case nodes propagate to outputs of conditional + // outputs of case nodes propagate to outputs of conditional *if* case reachable out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- - case_node(c, cond, _, case), + case_node(c, cond, i, case), + in_wire_value(c, cond, IncomingPort::from(0), control), + if control.supports_tag(*i), io_node(c, case, o, IO::Output), in_wire_value(c, o, o_p, v); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index e3e8545ec..40481855c 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -381,11 +381,11 @@ fn conditional() { #[rstest] #[case(pv_true(), pv_true(), pv_true())] // OK -#[case(pv_true(), pv_false(), pv_true_or_false())] // Result should be false ?? -#[case(pv_false(), pv_true(), pv_true_or_false())] // Result should be false ?? -#[case(pv_false(), pv_false(), pv_true_or_false())] // Result should be true?? +#[case(pv_true(), pv_false(), pv_false())] // OK +#[case(pv_false(), pv_true(), pv_false())] // OK +#[case(pv_false(), pv_false(), pv_true())] // OK #[case(PartialValue::top(), pv_true(), PartialValue::top())] // Result should be true_or_false? TOP means all inputs inside cases are TOP -#[case(PartialValue::top(), pv_false(), PartialValue::top())] // Result should be true_or_false? +#[case(PartialValue::top(), pv_false(), pv_true_or_false())] // OK fn cfg( #[case] inp0: PartialValue, #[case] inp1: PartialValue, From 6a2dd9e79ab83d92d7acb049d64cc8b1022b90ec Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Oct 2024 16:15:17 +0100 Subject: [PATCH 114/281] More test cases --- hugr-passes/src/dataflow/test.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 40481855c..71b0425bf 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -380,12 +380,16 @@ fn conditional() { } #[rstest] -#[case(pv_true(), pv_true(), pv_true())] // OK -#[case(pv_true(), pv_false(), pv_false())] // OK -#[case(pv_false(), pv_true(), pv_false())] // OK -#[case(pv_false(), pv_false(), pv_true())] // OK -#[case(PartialValue::top(), pv_true(), PartialValue::top())] // Result should be true_or_false? TOP means all inputs inside cases are TOP -#[case(PartialValue::top(), pv_false(), pv_true_or_false())] // OK +#[case(pv_true(), pv_true(), pv_true())] +#[case(pv_true(), pv_false(), pv_false())] +#[case(pv_true(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_false(), pv_true(), pv_false())] +#[case(pv_false(), pv_false(), pv_true())] +#[case(pv_false(), pv_true_or_false(), pv_true_or_false())] +#[case(PartialValue::top(), pv_true(), PartialValue::top())] // Ideally, result should be true_or_false +#[case(PartialValue::top(), pv_false(), pv_true_or_false())] +#[case(pv_true_or_false(), pv_true(), pv_true_or_false())] +#[case(pv_true_or_false(), pv_false(), pv_true_or_false())] fn cfg( #[case] inp0: PartialValue, #[case] inp1: PartialValue, From a75fee947bc64a0eea89f0d5a13768b44a96ebf1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Oct 2024 16:23:40 +0100 Subject: [PATCH 115/281] refactor as fixture --- hugr-passes/src/dataflow/test.rs | 51 ++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 71b0425bf..44f2bcaba 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -5,7 +5,6 @@ use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::builder::CFGBuilder; use hugr_core::types::TypeRow; -use hugr_core::Wire; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -17,7 +16,8 @@ use hugr_core::{ types::{Signature, SumType, Type}, HugrView, }; -use rstest::rstest; +use hugr_core::{Hugr, Node, Wire}; +use rstest::{fixture, rstest}; use super::{AbstractValue, BaseValue, DFContext, Machine, PartialValue, TailLoopTermination}; @@ -379,22 +379,14 @@ fn conditional() { assert_eq!(machine.case_reachable(&hugr, cond.node()), None); } -#[rstest] -#[case(pv_true(), pv_true(), pv_true())] -#[case(pv_true(), pv_false(), pv_false())] -#[case(pv_true(), pv_true_or_false(), pv_true_or_false())] -#[case(pv_false(), pv_true(), pv_false())] -#[case(pv_false(), pv_false(), pv_true())] -#[case(pv_false(), pv_true_or_false(), pv_true_or_false())] -#[case(PartialValue::top(), pv_true(), PartialValue::top())] // Ideally, result should be true_or_false -#[case(PartialValue::top(), pv_false(), pv_true_or_false())] -#[case(pv_true_or_false(), pv_true(), pv_true_or_false())] -#[case(pv_true_or_false(), pv_false(), pv_true_or_false())] -fn cfg( - #[case] inp0: PartialValue, - #[case] inp1: PartialValue, - #[case] outp: PartialValue, -) { +// Tuple of +// 1. Hugr being a function on bools: (b,c) => !b XOR c +// 2. Input node of entry block +// 3. Wire out from "True" constant +// Result readable from root node outputs +// Inputs should be placed onto out-wires of the Node (2.) +#[fixture] +fn xnor_cfg() -> (Hugr, Node, Wire) { // Entry // /0 1\ // A --1-> B @@ -478,8 +470,29 @@ fn cfg( builder.branch(&a, 1, &b).unwrap(); builder.branch(&b, 0, &x).unwrap(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let [entry_input, _] = hugr.get_io(entry.node()).unwrap(); + (hugr, entry_input, true_w) +} + +#[rstest] +#[case(pv_true(), pv_true(), pv_true())] +#[case(pv_true(), pv_false(), pv_false())] +#[case(pv_true(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_false(), pv_true(), pv_false())] +#[case(pv_false(), pv_false(), pv_true())] +#[case(pv_false(), pv_true_or_false(), pv_true_or_false())] +#[case(PartialValue::top(), pv_true(), PartialValue::top())] // Ideally, result should be true_or_false +#[case(PartialValue::top(), pv_false(), pv_true_or_false())] +#[case(pv_true_or_false(), pv_true(), pv_true_or_false())] +#[case(pv_true_or_false(), pv_false(), pv_true_or_false())] +fn test_cfg( + #[case] inp0: PartialValue, + #[case] inp1: PartialValue, + #[case] outp: PartialValue, + xnor_cfg: (Hugr, Node, Wire), +) { + let (hugr, entry_input, true_w) = xnor_cfg; + let [in_w0, in_w1] = [0, 1].map(|i| Wire::new(entry_input, i)); let mut machine = Machine::default(); From 151e571c82804ac927292e85f681d314657eb87c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Oct 2024 18:59:51 +0100 Subject: [PATCH 116/281] clippy --- hugr-passes/src/dataflow/test.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 44f2bcaba..2d652c7a4 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -4,7 +4,6 @@ use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::builder::CFGBuilder; -use hugr_core::types::TypeRow; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ From e15b04d0806031f894dadb6501c669177811177e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 14 Oct 2024 10:33:30 +0100 Subject: [PATCH 117/281] Revert "Datalog works on any AbstractValue; impl'd by PartialValue for a BaseValue" This reverts commit 7f2a91a5fc5bc26143f5e82543c172e10ebea90d. --- hugr-passes/src/dataflow.rs | 18 ++++- hugr-passes/src/dataflow/datalog.rs | 91 +++++++++-------------- hugr-passes/src/dataflow/machine.rs | 22 +++--- hugr-passes/src/dataflow/partial_value.rs | 77 +++++++++---------- hugr-passes/src/dataflow/test.rs | 6 +- 5 files changed, 103 insertions(+), 111 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 6085c3e92..5cf5d91eb 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -2,13 +2,27 @@ //! Dataflow analysis of Hugrs. mod datalog; -pub use datalog::{AbstractValue, DFContext}; mod machine; pub use machine::{Machine, TailLoopTermination}; mod partial_value; -pub use partial_value::{BaseValue, PVEnum, PartialSum, PartialValue, Sum}; +pub use partial_value::{AbstractValue, PVEnum, PartialSum, PartialValue, Sum}; + +use hugr_core::{Hugr, Node}; +use std::hash::Hash; + +/// Clients of the dataflow framework (particular analyses, such as constant folding) +/// must implement this trait (including providing an appropriate domain type `V`). +pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { + /// Given lattice values for each input, produce lattice values for (what we know of) + /// the outputs. Returning `None` indicates nothing can be deduced. + fn interpret_leaf_op( + &self, + node: Node, + ins: &[PartialValue], + ) -> Option>>; +} #[cfg(test)] mod test; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 258d76c0d..4ab12e380 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -16,7 +16,11 @@ use std::ops::{Index, IndexMut}; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use hugr_core::ops::OpType; use hugr_core::types::Signature; -use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; +use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; + +use super::{AbstractValue, DFContext, PartialValue}; + +type PV = PartialValue; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum IO { @@ -24,50 +28,19 @@ pub enum IO { Output, } -/// Clients of the dataflow framework (particular analyses, such as constant folding) -/// must implement this trait (including providing an appropriate domain type `PV`). -pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { - /// Given lattice values for each input, produce lattice values for (what we know of) - /// the outputs. Returning `None` indicates nothing can be deduced. - fn interpret_leaf_op(&self, node: Node, ins: &[PV]) -> Option>; -} - -/// Values which can be the domain for dataflow analysis. Must be able to deconstructed -/// into (and constructed from) Sums as these determine control flow. -pub trait AbstractValue: BoundedLattice + Clone + Eq + Hash + std::fmt::Debug { - /// Create a new instance representing a Sum with a single known tag - /// and (recursive) representations of the elements within that tag. - fn new_variant(tag: usize, values: impl IntoIterator) -> Self; - - /// New instance of unit type (i.e. the only possible value, with no contents) - fn new_unit() -> Self { - Self::new_variant(0, []) - } - - /// Test whether this value *might* be a Sum with the specified tag. - fn supports_tag(&self, tag: usize) -> bool; - - /// If this value might be a Sum with the specified tag, return values - /// describing the elements of the Sum, otherwise `None`. - /// - /// Implementations must hold the invariant that for all `x`, `tag` and `len`: - /// `x.variant_values(tag, len).is_some() == x.supports_tag(tag)` - fn variant_values(&self, tag: usize, len: usize) -> Option>; -} - ascent::ascent! { - pub(super) struct AscentProgram>; + pub(super) struct AscentProgram>; relation context(C); - relation out_wire_value_proto(Node, OutgoingPort, PV); + relation out_wire_value_proto(Node, OutgoingPort, PV); relation node(C, Node); relation in_wire(C, Node, IncomingPort); relation out_wire(C, Node, OutgoingPort); relation parent_of_node(C, Node, Node); relation io_node(C, Node, Node, IO); - lattice out_wire_value(C, Node, OutgoingPort, PV); - lattice node_in_value_row(C, Node, ValueRow); - lattice in_wire_value(C, Node, IncomingPort, PV); + lattice out_wire_value(C, Node, OutgoingPort, PV); + lattice node_in_value_row(C, Node, ValueRow); + lattice in_wire_value(C, Node, IncomingPort, PV); node(c, n) <-- context(c), for n in c.nodes(); @@ -200,11 +173,11 @@ ascent::ascent! { for (out_p, v) in fields.enumerate(); } -fn propagate_leaf_op( - c: &impl DFContext, +fn propagate_leaf_op( + c: &impl DFContext, n: Node, - ins: &[PV], -) -> Option> { + ins: &[PV], +) -> Option> { match c.get_optype(n) { // Handle basics here. I guess (given the current interface) we could allow // DFContext to handle these but at the least we'd want these impls to be @@ -248,33 +221,37 @@ fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator(Vec); +struct ValueRow(Vec>); -impl ValueRow { +impl ValueRow { fn new(len: usize) -> Self { - Self(vec![PV::bottom(); len]) + Self(vec![PartialValue::bottom(); len]) } - fn single_known(len: usize, idx: usize, v: PV) -> Self { + fn single_known(len: usize, idx: usize, v: PartialValue) -> Self { assert!(idx < len); let mut r = Self::new(len); r.0[idx] = v; r } - fn unpack_first(&self, variant: usize, len: usize) -> Option> { + pub fn unpack_first( + &self, + variant: usize, + len: usize, + ) -> Option>> { let vals = self[0].variant_values(variant, len)?; Some(vals.into_iter().chain(self.0[1..].to_owned())) } } -impl FromIterator for ValueRow { - fn from_iter>(iter: T) -> Self { +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { Self(iter.into_iter().collect()) } } -impl PartialOrd for ValueRow { +impl PartialOrd for ValueRow { fn partial_cmp(&self, other: &Self) -> Option { self.0.partial_cmp(&other.0) } @@ -300,30 +277,30 @@ impl Lattice for ValueRow { } } -impl IntoIterator for ValueRow { - type Item = PV; +impl IntoIterator for ValueRow { + type Item = PartialValue; - type IntoIter = as IntoIterator>::IntoIter; + type IntoIter = > as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() } } -impl Index for ValueRow +impl Index for ValueRow where - Vec: Index, + Vec>: Index, { - type Output = as Index>::Output; + type Output = > as Index>::Output; fn index(&self, index: Idx) -> &Self::Output { self.0.index(index) } } -impl IndexMut for ValueRow +impl IndexMut for ValueRow where - Vec: IndexMut, + Vec>: IndexMut, { fn index_mut(&mut self, index: Idx) -> &mut Self::Output { self.0.index_mut(index) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index aa9408cdb..15262d4db 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -2,16 +2,16 @@ use std::collections::HashMap; use hugr_core::{HugrView, Node, PortIndex, Wire}; -use super::{datalog::AscentProgram, AbstractValue, DFContext}; +use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] /// 2. Zero or more [Self::propolutate_out_wires] with initial values /// 3. Exactly one [Self::run] to do the analysis /// 4. Results then available via [Self::read_out_wire] -pub struct Machine>( - AscentProgram, - Option>, +pub struct Machine>( + AscentProgram, + Option>>, ); /// derived-Default requires the context to be Defaultable, which is unnecessary @@ -21,10 +21,13 @@ impl> Default for Machine { } } -impl> Machine { +impl> Machine { /// Provide initial values for some wires. /// (For example, if some properties of the Hugr's inputs are known.) - pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator) { + pub fn propolutate_out_wires( + &mut self, + wires: impl IntoIterator)>, + ) { assert!(self.1.is_none()); self.0 .out_wire_value_proto @@ -52,7 +55,7 @@ impl> Machine { } /// Gets the lattice value computed by [Self::run] for the given wire - pub fn read_out_wire(&self, w: Wire) -> Option { + pub fn read_out_wire(&self, w: Wire) -> Option> { self.1.as_ref().unwrap().get(&w).cloned() } @@ -113,10 +116,7 @@ pub enum TailLoopTermination { } impl TailLoopTermination { - /// Extracts the relevant information from a value that should represent - /// the value provided to the [Output](hugr_core::ops::Output) node child - /// of the [TailLoop](hugr_core::ops::TailLoop) - pub fn from_control_value(v: &impl AbstractValue) -> Self { + fn from_control_value(v: &PartialValue) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); if may_break { if may_continue { diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 2e67d0bb2..880a30241 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -8,11 +8,9 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; -use super::AbstractValue; - -/// Trait for abstract values that can be wrapped by [PartialValue] for dataflow analysis. -/// (Allows the values to represent sums, but does not require this). -pub trait BaseValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { +/// Trait for an underlying domain of abstract values which can form the *elements* of a +/// [PartialValue] and thus be used in dataflow analysis. +pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { /// If the abstract value represents a [Sum] with a single known tag, deconstruct it /// into that tag plus the elements. The default just returns `None` which is /// appropriate if the abstract value never does (in which case [interpret_leaf_op] @@ -65,7 +63,7 @@ impl PartialSum { } } -impl PartialSum { +impl PartialSum { fn assert_invariants(&self) { assert_ne!(self.num_variants(), 0); for pv in self.0.values().flat_map(|x| x.iter()) { @@ -232,15 +230,15 @@ impl Hash for PartialSum { } } -/// Wraps some underlying representation of values (that `impl`s [BaseValue]) into -/// a lattice for use in dataflow analysis, including that an instance may be -/// a [PartialSum] of values of the underlying representation +/// Wraps some underlying representation (knowledge) of values into a lattice +/// for use in dataflow analysis, including that an instance may be a [PartialSum] +/// of values of the underlying representation #[derive(PartialEq, Clone, Eq, Hash, Debug)] pub struct PartialValue(PVEnum); impl PartialValue { /// Allows to read the enum, which guarantees that we never return [PVEnum::Value] - /// for a value whose [BaseValue::as_sum] is `Some` - any such value will be + /// for a value whose [AbstractValue::as_sum] is `Some` - any such value will be /// in the form of a [PVEnum::Sum] instead. pub fn as_enum(&self) -> &PVEnum { &self.0 @@ -260,7 +258,7 @@ pub enum PVEnum { Top, } -impl From for PartialValue { +impl From for PartialValue { fn from(v: V) -> Self { v.as_sum() .map(|(tag, values)| Self::new_variant(tag, values.map(Self::from))) @@ -274,7 +272,7 @@ impl From> for PartialValue { } } -impl PartialValue { +impl PartialValue { fn assert_invariants(&self) { match &self.0 { PVEnum::Sum(ps) => { @@ -287,30 +285,22 @@ impl PartialValue { } } - /// Extracts a value (in any representation supporting both leaf values and sums) - // ALAN is this too parametric? Should we fix V2 == Value? Is the error useful (should we have 'Self') or is it a smell? - pub fn try_into_value + TryFrom>>( - self, - typ: &Type, - ) -> Result>>::Error>> { - match self.0 { - PVEnum::Value(v) => Ok(V2::from(v.clone())), - PVEnum::Sum(ps) => { - let v = ps.try_into_value(typ).map_err(|_| None)?; - V2::try_from(v).map_err(Some) - } - _ => Err(None), - } + /// New instance of a sum with a single known tag. + pub fn new_variant(tag: usize, values: impl IntoIterator) -> Self { + PartialSum::new_variant(tag, values).into() + } + + /// New instance of unit type (i.e. the only possible value, with no contents) + pub fn new_unit() -> Self { + Self::new_variant(0, []) } -} -impl AbstractValue for PartialValue { /// If this value might be a Sum with the specified `tag`, get the elements inside that tag. /// /// # Panics /// /// if the value is believed, for that tag, to have a number of values other than `len` - fn variant_values(&self, tag: usize, len: usize) -> Option>> { + pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match &self.0 { PVEnum::Bottom => return None, PVEnum::Value(v) => { @@ -325,7 +315,7 @@ impl AbstractValue for PartialValue { } /// Tells us whether this value might be a Sum with the specified `tag` - fn supports_tag(&self, tag: usize) -> bool { + pub fn supports_tag(&self, tag: usize) -> bool { match &self.0 { PVEnum::Bottom => false, PVEnum::Value(v) => { @@ -337,8 +327,20 @@ impl AbstractValue for PartialValue { } } - fn new_variant(tag: usize, values: impl IntoIterator) -> Self { - PartialSum::new_variant(tag, values).into() + /// Extracts a value (in any representation supporting both leaf values and sums) + // ALAN is this too parametric? Should we fix V2 == Value? Is the error useful (should we have 'Self') or is it a smell? + pub fn try_into_value + TryFrom>>( + self, + typ: &Type, + ) -> Result>>::Error>> { + match self.0 { + PVEnum::Value(v) => Ok(V2::from(v.clone())), + PVEnum::Sum(ps) => { + let v = ps.try_into_value(typ).map_err(|_| None)?; + V2::try_from(v).map_err(Some) + } + _ => Err(None), + } } } @@ -350,7 +352,7 @@ impl TryFrom> for Value { } } -impl PartialValue +impl PartialValue where Value: From, { @@ -377,7 +379,7 @@ where } } -impl Lattice for PartialValue { +impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); // println!("join {self:?}\n{:?}", &other); @@ -463,7 +465,7 @@ impl Lattice for PartialValue { } } -impl BoundedLattice for PartialValue { +impl BoundedLattice for PartialValue { fn top() -> Self { Self(PVEnum::Top) } @@ -501,8 +503,7 @@ mod test { use proptest_recurse::{StrategyExt, StrategySet}; - use super::{BaseValue, PVEnum, PartialSum, PartialValue}; - use crate::dataflow::AbstractValue; + use super::{AbstractValue, PVEnum, PartialSum, PartialValue}; #[derive(Debug, PartialEq, Eq, Clone)] enum TestSumType { @@ -514,7 +515,7 @@ mod test { #[derive(Clone, Debug, PartialEq, Eq, Hash)] struct TestValue(usize); - impl BaseValue for TestValue {} + impl AbstractValue for TestValue {} #[derive(Clone)] struct SumTypeParams { diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 2d652c7a4..66c1c80f5 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -18,13 +18,13 @@ use hugr_core::{ use hugr_core::{Hugr, Node, Wire}; use rstest::{fixture, rstest}; -use super::{AbstractValue, BaseValue, DFContext, Machine, PartialValue, TailLoopTermination}; +use super::{AbstractValue, DFContext, Machine, PartialValue, TailLoopTermination}; // ------- Minimal implementation of DFContext and BaseValue ------- #[derive(Debug, Clone, PartialEq, Eq, Hash)] enum Void {} -impl BaseValue for Void {} +impl AbstractValue for Void {} struct TestContext(Arc); @@ -68,7 +68,7 @@ impl PartialOrd for TestContext { } } -impl DFContext> for TestContext { +impl DFContext for TestContext { fn interpret_leaf_op( &self, _node: hugr_core::Node, From 949ef70763a244db4edbe7db7c0c6ee204ba26cc Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 14 Oct 2024 10:41:17 +0100 Subject: [PATCH 118/281] Update const_fold2/value_handle.rs, dataflow/total_context.rs --- hugr-passes/src/const_fold2/value_handle.rs | 4 ++-- hugr-passes/src/dataflow/total_context.rs | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index 59a08b50a..bbcd25129 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -7,7 +7,7 @@ use hugr_core::ops::Value; use hugr_core::types::Type; use hugr_core::Node; -use crate::dataflow::BaseValue; +use crate::dataflow::AbstractValue; #[derive(Clone, Debug)] pub struct HashedConst { @@ -85,7 +85,7 @@ impl ValueHandle { } } -impl BaseValue for ValueHandle { +impl AbstractValue for ValueHandle { fn as_sum(&self) -> Option<(usize, impl Iterator + '_)> { match self.value() { Value::Sum(Sum { tag, values, .. }) => Some(( diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index d512912d0..9bc0a417e 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -3,8 +3,7 @@ use std::hash::Hash; use ascent::lattice::BoundedLattice; use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; -use super::partial_value::{PartialValue, Sum}; -use super::{BaseValue, DFContext}; +use super::{AbstractValue, DFContext, PartialValue, Sum}; /// A simpler interface like [DFContext] but where the context only cares about /// values that are completely known (as `V`s), i.e. not `Bottom`, `Top`, or @@ -22,7 +21,7 @@ pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { ) -> Vec<(OutgoingPort, V)>; } -impl> DFContext> for T { +impl> DFContext for T { fn interpret_leaf_op( &self, node: Node, From dc08f0d221cc67516731bfaa45f5db190a69f1d4 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 14 Oct 2024 11:03:33 +0100 Subject: [PATCH 119/281] (Re-)remove PVEnum --- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/dataflow/partial_value.rs | 153 ++++++++++------------ 2 files changed, 72 insertions(+), 83 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 5cf5d91eb..f786d62c7 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -7,7 +7,7 @@ mod machine; pub use machine::{Machine, TailLoopTermination}; mod partial_value; -pub use partial_value::{AbstractValue, PVEnum, PartialSum, PartialValue, Sum}; +pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; use hugr_core::{Hugr, Node}; use std::hash::Hash; diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 880a30241..d0f7c7ef4 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -234,26 +234,13 @@ impl Hash for PartialSum { /// for use in dataflow analysis, including that an instance may be a [PartialSum] /// of values of the underlying representation #[derive(PartialEq, Clone, Eq, Hash, Debug)] -pub struct PartialValue(PVEnum); - -impl PartialValue { - /// Allows to read the enum, which guarantees that we never return [PVEnum::Value] - /// for a value whose [AbstractValue::as_sum] is `Some` - any such value will be - /// in the form of a [PVEnum::Sum] instead. - pub fn as_enum(&self) -> &PVEnum { - &self.0 - } -} - -/// The contents of a [PartialValue], i.e. used as a view. -#[derive(PartialEq, Clone, Eq, Hash, Debug)] -pub enum PVEnum { +pub enum PartialValue { /// No possibilities known (so far) Bottom, /// A single value (of the underlying representation) Value(V), - /// Sum (with perhaps several possible tags) of underlying values - Sum(PartialSum), + /// Sum (with at least one, perhaps several, possible tags) of underlying values + PartialSum(PartialSum), /// Might be more than one distinct value of the underlying type `V` Top, } @@ -262,23 +249,23 @@ impl From for PartialValue { fn from(v: V) -> Self { v.as_sum() .map(|(tag, values)| Self::new_variant(tag, values.map(Self::from))) - .unwrap_or(Self(PVEnum::Value(v))) + .unwrap_or(Self::Value(v)) } } impl From> for PartialValue { fn from(v: PartialSum) -> Self { - Self(PVEnum::Sum(v)) + Self::PartialSum(v) } } impl PartialValue { fn assert_invariants(&self) { - match &self.0 { - PVEnum::Sum(ps) => { + match self { + Self::PartialSum(ps) => { ps.assert_invariants(); } - PVEnum::Value(v) => { + Self::Value(v) => { assert!(v.as_sum().is_none()) } _ => {} @@ -301,14 +288,14 @@ impl PartialValue { /// /// if the value is believed, for that tag, to have a number of values other than `len` pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { - let vals = match &self.0 { - PVEnum::Bottom => return None, - PVEnum::Value(v) => { + let vals = match self { + PartialValue::Bottom => return None, + PartialValue::Value(v) => { assert!(v.as_sum().is_none()); return None; } - PVEnum::Sum(ps) => ps.variant_values(tag)?, - PVEnum::Top => vec![PartialValue(PVEnum::Top); len], + PartialValue::PartialSum(ps) => ps.variant_values(tag)?, + PartialValue::Top => vec![PartialValue::Top; len], }; assert_eq!(vals.len(), len); Some(vals) @@ -316,14 +303,14 @@ impl PartialValue { /// Tells us whether this value might be a Sum with the specified `tag` pub fn supports_tag(&self, tag: usize) -> bool { - match &self.0 { - PVEnum::Bottom => false, - PVEnum::Value(v) => { + match self { + PartialValue::Bottom => false, + PartialValue::Value(v) => { assert!(v.as_sum().is_none()); false } - PVEnum::Sum(ps) => ps.supports_tag(tag), - PVEnum::Top => true, + PartialValue::PartialSum(ps) => ps.supports_tag(tag), + PartialValue::Top => true, } } @@ -333,9 +320,9 @@ impl PartialValue { self, typ: &Type, ) -> Result>>::Error>> { - match self.0 { - PVEnum::Value(v) => Ok(V2::from(v.clone())), - PVEnum::Sum(ps) => { + match self { + Self::Value(v) => Ok(V2::from(v.clone())), + Self::PartialSum(ps) => { let v = ps.try_into_value(typ).map_err(|_| None)?; V2::try_from(v).map_err(Some) } @@ -356,8 +343,9 @@ impl PartialValue where Value: From, { - /// Turns this instance into a [Value], if it is either a single [value](PVEnum::Value) or - /// a [sum](PVEnum::Sum) with a single known tag, extracting the desired type from a HugrView and Wire. + /// Turns this instance into a [Value], if it is either a single [Value](Self::Value) or + /// a [Sum](PartialValue::PartialSum) with a single known tag, extracting the desired type + /// from a HugrView and Wire. /// /// # Errors /// `None` if the analysis did not result in a single value on that wire @@ -383,40 +371,41 @@ impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); // println!("join {self:?}\n{:?}", &other); - match (&self.0, other.0) { - (PVEnum::Top, _) => false, - (_, other @ PVEnum::Top) => { - self.0 = other; + match (&*self, other) { + (Self::Top, _) => false, + (_, other @ Self::Top) => { + *self = other; true } - (_, PVEnum::Bottom) => false, - (PVEnum::Bottom, other) => { - self.0 = other; + (_, Self::Bottom) => false, + (Self::Bottom, other) => { + *self = other; true } - (PVEnum::Value(h1), PVEnum::Value(h2)) => { + (Self::Value(h1), Self::Value(h2)) => { if h1 == &h2 { false } else { - self.0 = PVEnum::Top; + *self = Self::Top; true } } - (PVEnum::Sum(_), PVEnum::Sum(ps2)) => { - let Self(PVEnum::Sum(ps1)) = self else { + (Self::PartialSum(_), Self::PartialSum(ps2)) => { + let Self::PartialSum(ps1) = self else { unreachable!() }; match ps1.try_join_mut(ps2) { Ok(ch) => ch, Err(_) => { - self.0 = PVEnum::Top; + *self = Self::Top; true } } } - (PVEnum::Value(ref v), PVEnum::Sum(_)) | (PVEnum::Sum(_), PVEnum::Value(ref v)) => { + (Self::Value(ref v), Self::PartialSum(_)) + | (Self::PartialSum(_), Self::Value(ref v)) => { assert!(v.as_sum().is_none()); - self.0 = PVEnum::Top; + *self = Self::Top; true } } @@ -424,41 +413,41 @@ impl Lattice for PartialValue { fn meet_mut(&mut self, other: Self) -> bool { self.assert_invariants(); - match (&self.0, other.0) { - (PVEnum::Bottom, _) => false, - (_, other @ PVEnum::Bottom) => { - self.0 = other; + match (&*self, other) { + (Self::Bottom, _) => false, + (_, other @ Self::Bottom) => { + *self = other; true } - (_, PVEnum::Top) => false, - (PVEnum::Top, other) => { - self.0 = other; + (_, Self::Top) => false, + (Self::Top, other) => { + *self = other; true } - (PVEnum::Value(h1), PVEnum::Value(h2)) => { + (Self::Value(h1), Self::Value(h2)) => { if h1 == &h2 { false } else { - self.0 = PVEnum::Bottom; + *self = Self::Bottom; true } } - (PVEnum::Sum(_), PVEnum::Sum(ps2)) => { - let ps1 = match &mut self.0 { - PVEnum::Sum(ps1) => ps1, - _ => unreachable!(), + (Self::PartialSum(_), Self::PartialSum(ps2)) => { + let Self::PartialSum(ps1) = self else { + unreachable!() }; match ps1.try_meet_mut(ps2) { Ok(ch) => ch, Err(_) => { - self.0 = PVEnum::Bottom; + *self = Self::Bottom; true } } } - (PVEnum::Value(ref v), PVEnum::Sum(_)) | (PVEnum::Sum(_), PVEnum::Value(ref v)) => { + (Self::Value(ref v), Self::PartialSum(_)) + | (Self::PartialSum(_), Self::Value(ref v)) => { assert!(v.as_sum().is_none()); - self.0 = PVEnum::Bottom; + *self = Self::Bottom; true } } @@ -467,26 +456,26 @@ impl Lattice for PartialValue { impl BoundedLattice for PartialValue { fn top() -> Self { - Self(PVEnum::Top) + Self::Top } fn bottom() -> Self { - Self(PVEnum::Bottom) + Self::Bottom } } impl PartialOrd for PartialValue { fn partial_cmp(&self, other: &Self) -> Option { use std::cmp::Ordering; - match (&self.0, &other.0) { - (PVEnum::Bottom, PVEnum::Bottom) => Some(Ordering::Equal), - (PVEnum::Top, PVEnum::Top) => Some(Ordering::Equal), - (PVEnum::Bottom, _) => Some(Ordering::Less), - (_, PVEnum::Bottom) => Some(Ordering::Greater), - (PVEnum::Top, _) => Some(Ordering::Greater), - (_, PVEnum::Top) => Some(Ordering::Less), - (PVEnum::Value(v1), PVEnum::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), - (PVEnum::Sum(ps1), PVEnum::Sum(ps2)) => ps1.partial_cmp(ps2), + match (self, other) { + (Self::Bottom, Self::Bottom) => Some(Ordering::Equal), + (Self::Top, Self::Top) => Some(Ordering::Equal), + (Self::Bottom, _) => Some(Ordering::Less), + (_, Self::Bottom) => Some(Ordering::Greater), + (Self::Top, _) => Some(Ordering::Greater), + (_, Self::Top) => Some(Ordering::Less), + (Self::Value(v1), Self::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), + (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), _ => None, } } @@ -503,7 +492,7 @@ mod test { use proptest_recurse::{StrategyExt, StrategySet}; - use super::{AbstractValue, PVEnum, PartialSum, PartialValue}; + use super::{AbstractValue, PartialSum, PartialValue}; #[derive(Debug, PartialEq, Eq, Clone)] enum TestSumType { @@ -536,11 +525,11 @@ mod test { impl TestSumType { fn check_value(&self, pv: &PartialValue) -> bool { - match (self, pv.as_enum()) { - (_, PVEnum::Bottom) | (_, PVEnum::Top) => true, + match (self, pv) { + (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, (Self::Leaf(None), _) => pv == &PartialValue::new_unit(), - (Self::Leaf(Some(max)), PVEnum::Value(TestValue(val))) => val <= max, - (Self::Branch(sop), PVEnum::Sum(ps)) => { + (Self::Leaf(Some(max)), PartialValue::Value(TestValue(val))) => val <= max, + (Self::Branch(sop), PartialValue::PartialSum(ps)) => { for (k, v) in &ps.0 { if *k >= sop.len() { return false; From 436dcd277cf367c0c918e39313414f0f44a9b2d7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 14 Oct 2024 11:19:49 +0100 Subject: [PATCH 120/281] Remove as_sum. AbstractValues are elements not sums --- hugr-passes/src/dataflow/partial_value.rs | 83 ++++++++++------------- 1 file changed, 37 insertions(+), 46 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index d0f7c7ef4..0c2f80ca4 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -11,22 +11,24 @@ use std::hash::{Hash, Hasher}; /// Trait for an underlying domain of abstract values which can form the *elements* of a /// [PartialValue] and thus be used in dataflow analysis. pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { - /// If the abstract value represents a [Sum] with a single known tag, deconstruct it - /// into that tag plus the elements. The default just returns `None` which is - /// appropriate if the abstract value never does (in which case [interpret_leaf_op] - /// must produce a [PartialValue::new_variant] for any operation producing - /// a sum). + /// Computes the join of two values (i.e. towards `Top``), if this is representable + /// within the underlying domain. + /// Otherwise return `None` (i.e. an instruction to use [PartialValue::Top]). /// - /// The signature is this way to optimize query/inspection (is-it-a-sum), - /// at the cost of requiring more cloning during actual conversion - /// (inside the lazy Iterator, or for the error case, as Self remains) + /// The default checks equality between `self` and `other` and returns `self` if + /// the two are identical, otherwise `None`. + fn try_join(self, other: Self) -> Option { + (self == other).then_some(self) + } + + /// Computes the meet of two values (i.e. towards `Bottom`), if this is representable + /// within the underlying domain. + /// Otherwise return `None` (i.e. an instruction to use [PartialValue::Bottom]). /// - /// [interpret_leaf_op]: super::DFContext::interpret_leaf_op - /// [Sum]: TypeEnum::Sum - /// [Tag]: hugr_core::ops::Tag - fn as_sum(&self) -> Option<(usize, impl Iterator + '_)> { - let res: Option<(usize, as IntoIterator>::IntoIter)> = None; - res + /// The default checks equality between `self` and `other` and returns `self` if + /// the two are identical, otherwise `None`. + fn try_meet(self, other: Self) -> Option { + (self == other).then_some(self) } } @@ -247,9 +249,7 @@ pub enum PartialValue { impl From for PartialValue { fn from(v: V) -> Self { - v.as_sum() - .map(|(tag, values)| Self::new_variant(tag, values.map(Self::from))) - .unwrap_or(Self::Value(v)) + Self::Value(v) } } @@ -265,9 +265,6 @@ impl PartialValue { Self::PartialSum(ps) => { ps.assert_invariants(); } - Self::Value(v) => { - assert!(v.as_sum().is_none()) - } _ => {} } } @@ -289,11 +286,7 @@ impl PartialValue { /// if the value is believed, for that tag, to have a number of values other than `len` pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match self { - PartialValue::Bottom => return None, - PartialValue::Value(v) => { - assert!(v.as_sum().is_none()); - return None; - } + PartialValue::Bottom | PartialValue::Value(_) => return None, PartialValue::PartialSum(ps) => ps.variant_values(tag)?, PartialValue::Top => vec![PartialValue::Top; len], }; @@ -304,11 +297,7 @@ impl PartialValue { /// Tells us whether this value might be a Sum with the specified `tag` pub fn supports_tag(&self, tag: usize) -> bool { match self { - PartialValue::Bottom => false, - PartialValue::Value(v) => { - assert!(v.as_sum().is_none()); - false - } + PartialValue::Bottom | PartialValue::Value(_) => false, PartialValue::PartialSum(ps) => ps.supports_tag(tag), PartialValue::Top => true, } @@ -382,14 +371,17 @@ impl Lattice for PartialValue { *self = other; true } - (Self::Value(h1), Self::Value(h2)) => { - if h1 == &h2 { - false - } else { + (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_join(h2) { + Some(h3) => { + let ch = h3 != *h1; + *self = Self::Value(h3); + ch + } + None => { *self = Self::Top; true } - } + }, (Self::PartialSum(_), Self::PartialSum(ps2)) => { let Self::PartialSum(ps1) = self else { unreachable!() @@ -402,9 +394,7 @@ impl Lattice for PartialValue { } } } - (Self::Value(ref v), Self::PartialSum(_)) - | (Self::PartialSum(_), Self::Value(ref v)) => { - assert!(v.as_sum().is_none()); + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { *self = Self::Top; true } @@ -424,14 +414,17 @@ impl Lattice for PartialValue { *self = other; true } - (Self::Value(h1), Self::Value(h2)) => { - if h1 == &h2 { - false - } else { + (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_meet(h2) { + Some(h3) => { + let ch = h3 != *h1; + *self = Self::Value(h3); + ch + } + None => { *self = Self::Bottom; true } - } + }, (Self::PartialSum(_), Self::PartialSum(ps2)) => { let Self::PartialSum(ps1) = self else { unreachable!() @@ -444,9 +437,7 @@ impl Lattice for PartialValue { } } } - (Self::Value(ref v), Self::PartialSum(_)) - | (Self::PartialSum(_), Self::Value(ref v)) => { - assert!(v.as_sum().is_none()); + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { *self = Self::Bottom; true } From e817bbea5c897e3878b38d88ada8fd1e5c125ffb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 14 Oct 2024 14:44:04 +0100 Subject: [PATCH 121/281] clippy --- hugr-passes/src/dataflow/partial_value.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 0c2f80ca4..bdf774f2c 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -261,11 +261,8 @@ impl From> for PartialValue { impl PartialValue { fn assert_invariants(&self) { - match self { - Self::PartialSum(ps) => { - ps.assert_invariants(); - } - _ => {} + if let Self::PartialSum(ps) = self { + ps.assert_invariants(); } } From b06cfad2da7ec68efca1d03c66daacc4c1af7808 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 14 Oct 2024 14:00:10 +0100 Subject: [PATCH 122/281] Refactor: remove 'fn input_count' --- hugr-passes/src/dataflow/datalog.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 4ab12e380..b1c649252 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -15,7 +15,6 @@ use std::ops::{Index, IndexMut}; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use hugr_core::ops::OpType; -use hugr_core::types::Signature; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; use super::{AbstractValue, DFContext, PartialValue}; @@ -65,8 +64,8 @@ ascent::ascent! { out_wire_value(c, m, op, v); - node_in_value_row(c, n, ValueRow::new(input_count(c.as_ref(), *n))) <-- node(c, n); - node_in_value_row(c, n, ValueRow::single_known(c.signature(*n).unwrap().input.len(), p.index(), v.clone())) <-- in_wire_value(c, n, p, v); + node_in_value_row(c, n, ValueRow::new(sig.input_count())) <-- node(c, n), if let Some(sig) = c.signature(*n); + node_in_value_row(c, n, ValueRow::single_known(c.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(c, n, p, v); out_wire_value(c, n, p, v) <-- node(c, n), @@ -203,13 +202,6 @@ fn propagate_leaf_op( } } -fn input_count(h: &impl HugrView, n: Node) -> usize { - h.signature(n) - .as_ref() - .map(Signature::input_count) - .unwrap_or(0) -} - fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { h.in_value_types(n).map(|x| x.0) } From fd717be5a2f0f7d9ffcf7976316fd8476057cd5f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 14 Oct 2024 14:24:10 +0100 Subject: [PATCH 123/281] Try to fix interpret_leaf_op: cannot use Bottom for output! But ascent borrowing --- hugr-passes/src/dataflow.rs | 14 +++++++++----- hugr-passes/src/dataflow/datalog.rs | 28 +++++++++++++++++++++------- hugr-passes/src/dataflow/test.rs | 10 +--------- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index f786d62c7..51874136c 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -15,13 +15,17 @@ use std::hash::Hash; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { - /// Given lattice values for each input, produce lattice values for (what we know of) - /// the outputs. Returning `None` indicates nothing can be deduced. + /// Given lattice values for each input, update lattice values for the (dataflow) outputs. + /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] + /// which is the correct value to leave if nothing can be deduced about that output. + /// (The default does nothing, i.e. leaves `Top` for all outputs.) fn interpret_leaf_op( &self, - node: Node, - ins: &[PartialValue], - ) -> Option>>; + _node: Node, + _ins: &[PartialValue], + _outs: &mut [PartialValue], + ) { + } } #[cfg(test)] diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index b1c649252..70a404018 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -14,7 +14,7 @@ use std::hash::Hash; use std::ops::{Index, IndexMut}; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::OpType; +use hugr_core::ops::{OpTrait, OpType}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; use super::{AbstractValue, DFContext, PartialValue}; @@ -69,9 +69,11 @@ ascent::ascent! { out_wire_value(c, n, p, v) <-- node(c, n), - if !c.get_optype(*n).is_container(), + let op_t = c.get_optype(*n), + if !op_t.is_container(), + if let Some(sig) = op_t.dataflow_signature(), node_in_value_row(c, n, vs), - if let Some(outs) = propagate_leaf_op(c, *n, &vs[..]), + if let Some(outs) = propagate_leaf_op(c, *n, &vs[..], sig.output_count(), &self.out_wire_value_proto[..]), for (p,v) in (0..).map(OutgoingPort::from).zip(outs); // DFG @@ -176,6 +178,8 @@ fn propagate_leaf_op( c: &impl DFContext, n: Node, ins: &[PV], + num_outs: usize, + out_wire_proto: &[(Node, OutgoingPort, PV)], ) -> Option> { match c.get_optype(n) { // Handle basics here. I guess (given the current interface) we could allow @@ -195,10 +199,20 @@ fn propagate_leaf_op( ins.iter().cloned(), )])), OpType::Input(_) | OpType::Output(_) => None, // handled by parent - // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, - // thus keeping PartialValue hidden, but AbstractValues - // are not necessarily convertible to Value! - _ => c.interpret_leaf_op(n, ins).map(ValueRow::from_iter), + _ => { + // Interpret op. Default/worst-case is that we can't deduce anything about any + // output (just `Top`). + let mut outs = vec![PartialValue::Top; num_outs]; + // However, we may have been told better outcomes: + for (_, p, v) in out_wire_proto.iter().filter(|(n2, _, _)| n == n2) { + outs[p.index()] = v.clone() + } + // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, + // thus keeping PartialValue hidden, but AbstractValues + // are not necessarily convertible to Value! + c.interpret_leaf_op(n, ins, &mut outs[..]); + Some(ValueRow::from_iter(outs)) + } } } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 66c1c80f5..455554891 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -68,15 +68,7 @@ impl PartialOrd for TestContext { } } -impl DFContext for TestContext { - fn interpret_leaf_op( - &self, - _node: hugr_core::Node, - _ins: &[PartialValue], - ) -> Option>> { - None - } -} +impl DFContext for TestContext {} // This allows testing creation of tuple/sum Values (only) impl From for Value { From 9b174393ef81ea17929d0465b8a6b9b45f8974e3 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 14 Oct 2024 14:42:33 +0100 Subject: [PATCH 124/281] interpret_leaf_op for ExtensionOps only; LoadConstant via value_from_(custom_)const(_hugr) --- hugr-passes/src/dataflow.rs | 59 +++++++++++++++++++++++ hugr-passes/src/dataflow/datalog.rs | 33 ++++++++----- hugr-passes/src/dataflow/partial_value.rs | 10 ++-- hugr-passes/src/dataflow/test.rs | 16 ++---- 4 files changed, 91 insertions(+), 27 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 51874136c..370b4d643 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -4,11 +4,13 @@ mod datalog; mod machine; +use hugr_core::ops::constant::CustomConst; pub use machine::{Machine, TailLoopTermination}; mod partial_value; pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; +use hugr_core::ops::{ExtensionOp, Value}; use hugr_core::{Hugr, Node}; use std::hash::Hash; @@ -16,16 +18,73 @@ use std::hash::Hash; /// must implement this trait (including providing an appropriate domain type `V`). pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { /// Given lattice values for each input, update lattice values for the (dataflow) outputs. + /// For extension ops only, excluding [MakeTuple] and [UnpackTuple]. /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] /// which is the correct value to leave if nothing can be deduced about that output. /// (The default does nothing, i.e. leaves `Top` for all outputs.) + /// + /// [MakeTuple]: hugr_core::extension::prelude::MakeTuple + /// [UnpackTuple]: hugr_core::extension::prelude::UnpackTuple fn interpret_leaf_op( &self, _node: Node, + _e: &ExtensionOp, _ins: &[PartialValue], _outs: &mut [PartialValue], ) { } + + /// Produces an abstract value from a constant. The default impl + /// traverses the constant [Value] to its leaves ([Value::Extension] and [Value::Function]), + /// converts these using [Self::value_from_custom_const] and [Self::value_from_const_hugr], + /// and builds nested [PartialValue::new_variant] to represent the structure. + fn value_from_const(&self, n: Node, cst: &Value) -> PartialValue { + traverse_value(self, n, &mut Vec::new(), cst) + } + + /// Produces an abstract value from a [CustomConst], if possible. + /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. + fn value_from_custom_const( + &self, + _node: Node, + _fields: &[usize], + _cc: &dyn CustomConst, + ) -> Option { + None + } + + /// Produces an abstract value from a Hugr in a [Value::Function], if possible. + /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. + fn value_from_const_hugr(&self, _node: Node, _fields: &[usize], _h: &Hugr) -> Option { + None + } +} + +fn traverse_value( + s: &impl DFContext, + n: Node, + fields: &mut Vec, + cst: &Value, +) -> PartialValue { + match cst { + Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { + let elems = values.iter().enumerate().map(|(idx, elem)| { + fields.push(idx); + let r = traverse_value(s, n, fields, elem); + fields.pop(); + r + }); + PartialValue::new_variant(*tag, elems) + } + Value::Extension { e } => s + .value_from_custom_const(n, fields, e.value()) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + Value::Function { hugr } => s + .value_from_const_hugr(n, fields, &**hugr) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + } } #[cfg(test)] diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 70a404018..d88e5cc98 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -73,7 +73,7 @@ ascent::ascent! { if !op_t.is_container(), if let Some(sig) = op_t.dataflow_signature(), node_in_value_row(c, n, vs), - if let Some(outs) = propagate_leaf_op(c, *n, &vs[..], sig.output_count(), &self.out_wire_value_proto[..]), + if let Some(outs) = propagate_leaf_op(c, *n, &vs[..], sig.output_count()), for (p,v) in (0..).map(OutgoingPort::from).zip(outs); // DFG @@ -179,7 +179,6 @@ fn propagate_leaf_op( n: Node, ins: &[PV], num_outs: usize, - out_wire_proto: &[(Node, OutgoingPort, PV)], ) -> Option> { match c.get_optype(n) { // Handle basics here. I guess (given the current interface) we could allow @@ -198,21 +197,31 @@ fn propagate_leaf_op( t.tag, ins.iter().cloned(), )])), - OpType::Input(_) | OpType::Output(_) => None, // handled by parent - _ => { - // Interpret op. Default/worst-case is that we can't deduce anything about any - // output (just `Top`). + OpType::Input(_) | OpType::Output(_) | OpType::ExitBlock(_) => None, // handled by parent + OpType::Const(_) => None, // handled by LoadConstant: + OpType::LoadConstant(load_op) => { + assert!(ins.is_empty()); // static edge, so need to find constant + let const_node = c + .single_linked_output(n, load_op.constant_port()) + .unwrap() + .0; + let const_val = c.get_optype(const_node).as_const().unwrap().value(); + Some(ValueRow::single_known( + 1, + 0, + c.value_from_const(n, const_val), + )) + } + OpType::ExtensionOp(e) => { + // Interpret op. Default is we know nothing about the outputs (they still happen!) let mut outs = vec![PartialValue::Top; num_outs]; - // However, we may have been told better outcomes: - for (_, p, v) in out_wire_proto.iter().filter(|(n2, _, _)| n == n2) { - outs[p.index()] = v.clone() - } // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, // thus keeping PartialValue hidden, but AbstractValues - // are not necessarily convertible to Value! - c.interpret_leaf_op(n, ins, &mut outs[..]); + // are not necessarily convertible to Value. + c.interpret_leaf_op(n, e, ins, &mut outs[..]); Some(ValueRow::from_iter(outs)) } + o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" } } diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index bdf774f2c..d3010ae45 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -63,16 +63,16 @@ impl PartialSum { pub fn num_variants(&self) -> usize { self.0.len() } -} -impl PartialSum { fn assert_invariants(&self) { assert_ne!(self.num_variants(), 0); for pv in self.0.values().flat_map(|x| x.iter()) { pv.assert_invariants(); } } +} +impl PartialSum { /// Joins (towards `Top`) self with another [PartialSum]. If successful, returns /// whether `self` has changed. /// @@ -247,7 +247,7 @@ pub enum PartialValue { Top, } -impl From for PartialValue { +impl From for PartialValue { fn from(v: V) -> Self { Self::Value(v) } @@ -259,7 +259,7 @@ impl From> for PartialValue { } } -impl PartialValue { +impl PartialValue { fn assert_invariants(&self) { if let Self::PartialSum(ps) = self { ps.assert_invariants(); @@ -275,7 +275,9 @@ impl PartialValue { pub fn new_unit() -> Self { Self::new_variant(0, []) } +} +impl PartialValue { /// If this value might be a Sum with the specified `tag`, get the elements inside that tag. /// /// # Panics diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 455554891..a72ade121 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -98,7 +98,6 @@ fn test_make_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(v1, pv_false()), (v2, pv_true())]); machine.run(TestContext(Arc::new(&hugr))); let x = machine @@ -120,7 +119,6 @@ fn test_unpack_tuple_const() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(v, PartialValue::new_variant(0, [pv_false(), pv_true()]))]); machine.run(TestContext(Arc::new(&hugr))); let o1_r = machine @@ -153,7 +151,6 @@ fn test_tail_loop_never_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(r_w, PartialValue::new_variant(3, []))]); machine.run(TestContext(Arc::new(&hugr))); let o_r = machine @@ -227,7 +224,6 @@ fn test_tail_loop_two_iters() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(true_w, pv_true()), (false_w, pv_false())]); machine.run(TestContext(Arc::new(&hugr))); let o_r1 = machine.read_out_wire(o_w1).unwrap(); @@ -295,7 +291,6 @@ fn test_tail_loop_containing_conditional() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(init, PartialValue::new_variant(0, [pv_false(), pv_true()]))]); machine.run(TestContext(Arc::new(&hugr))); let o_r1 = machine.read_out_wire(o_w1).unwrap(); @@ -373,11 +368,10 @@ fn conditional() { // Tuple of // 1. Hugr being a function on bools: (b,c) => !b XOR c // 2. Input node of entry block -// 3. Wire out from "True" constant // Result readable from root node outputs // Inputs should be placed onto out-wires of the Node (2.) #[fixture] -fn xnor_cfg() -> (Hugr, Node, Wire) { +fn xnor_cfg() -> (Hugr, Node) { // Entry // /0 1\ // A --1-> B @@ -462,7 +456,7 @@ fn xnor_cfg() -> (Hugr, Node, Wire) { builder.branch(&b, 0, &x).unwrap(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [entry_input, _] = hugr.get_io(entry.node()).unwrap(); - (hugr, entry_input, true_w) + (hugr, entry_input) } #[rstest] @@ -480,14 +474,14 @@ fn test_cfg( #[case] inp0: PartialValue, #[case] inp1: PartialValue, #[case] outp: PartialValue, - xnor_cfg: (Hugr, Node, Wire), + xnor_cfg: (Hugr, Node), ) { - let (hugr, entry_input, true_w) = xnor_cfg; + let (hugr, entry_input) = xnor_cfg; let [in_w0, in_w1] = [0, 1].map(|i| Wire::new(entry_input, i)); let mut machine = Machine::default(); - machine.propolutate_out_wires([(in_w0, inp0), (in_w1, inp1), (true_w, pv_true())]); + machine.propolutate_out_wires([(in_w0, inp0), (in_w1, inp1)]); machine.run(TestContext(Arc::new(&hugr))); assert_eq!( From 1f410568e35c62fde9c2fc7b51d58342def9ae2b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 14 Oct 2024 11:09:18 +0100 Subject: [PATCH 125/281] Odd updates to total_context.rs - REVERT ?? --- hugr-passes/src/dataflow/total_context.rs | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index 9bc0a417e..dc3c7a69a 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -1,19 +1,15 @@ use std::hash::Hash; -use ascent::lattice::BoundedLattice; use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; -use super::{AbstractValue, DFContext, PartialValue, Sum}; +use super::partial_value::{AbstractValue, PartialValue, ValueOrSum}; +use super::DFContext; /// A simpler interface like [DFContext] but where the context only cares about -/// values that are completely known (as `V`s), i.e. not `Bottom`, `Top`, or -/// Sums of potentially multiple variants. +/// values that are completely known (in the lattice `V`) +/// rather than e.g. Sums potentially of two variants each of known values. pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { - /// The representation of values on which [Self::interpret_leaf_op] operates - type InterpretableVal: From + TryFrom>; - /// Interpret a leaf op. - /// `ins` gives the input ports for which we know (interpretable) values, and will be non-empty. - /// Returns a list of output ports for which we know (abstract) values (may be empty). + type InterpretableVal: TryFrom>; fn interpret_leaf_op( &self, node: Node, @@ -36,14 +32,17 @@ impl> DFContext for T { .zip(ins.iter()) .filter_map(|((i, ty), pv)| { pv.clone() - .try_into_value::<>::InterpretableVal>(ty) + .try_into_value(ty) + // Discard PVs which don't produce ValueOrSum, i.e. Bottom/Top :-) .ok() + // And discard any ValueOrSum that don't produce V - this is a bit silent :-( + .and_then(|v_s| T::InterpretableVal::try_from(v_s).ok()) .map(|v| (IncomingPort::from(i), v)) }) .collect::>(); let known_outs = self.interpret_leaf_op(node, &known_ins); (!known_outs.is_empty()).then(|| { - let mut res = vec![PartialValue::bottom(); sig.output_count()]; + let mut res = vec![PartialValue::Bottom; sig.output_count()]; for (p, v) in known_outs { res[p.index()] = v.into(); } From 846d1ee57fd057a20e4f522dd7e265956cd89572 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Oct 2024 09:39:51 +0100 Subject: [PATCH 126/281] Correct comment BaseValue -> AbstractValue --- hugr-passes/src/dataflow/test.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index a72ade121..56da25e5b 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -20,7 +20,7 @@ use rstest::{fixture, rstest}; use super::{AbstractValue, DFContext, Machine, PartialValue, TailLoopTermination}; -// ------- Minimal implementation of DFContext and BaseValue ------- +// ------- Minimal implementation of DFContext and AbstractValue ------- #[derive(Debug, Clone, PartialEq, Eq, Hash)] enum Void {} From 328e7f805e31f840fe90c89b13ba734c5c4a933d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Oct 2024 09:25:20 +0100 Subject: [PATCH 127/281] test Hugr now returns (XOR, AND) of two inputs, one case wrongly producing T|F --- hugr-passes/src/dataflow/test.rs | 116 +++++++++++++++++-------------- 1 file changed, 62 insertions(+), 54 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 56da25e5b..9ac49e6d1 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; -use hugr_core::builder::CFGBuilder; +use hugr_core::builder::{CFGBuilder, Container}; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -366,86 +366,88 @@ fn conditional() { } // Tuple of -// 1. Hugr being a function on bools: (b,c) => !b XOR c +// 1. Hugr being a function on bools: (x, y) => (x XOR y, x AND y) // 2. Input node of entry block // Result readable from root node outputs // Inputs should be placed onto out-wires of the Node (2.) #[fixture] -fn xnor_cfg() -> (Hugr, Node) { +fn xor_and_cfg() -> (Hugr, Node) { // Entry // /0 1\ - // A --1-> B - // \0 / + // A --1-> B A(x=true, y) => if y then X(false, true) else B(x=true) + // \0 / B(z) => X(z,false) // > X < - let mut builder = CFGBuilder::new(Signature::new(type_row![BOOL_T;2], BOOL_T)).unwrap(); - - // entry (i, j) => if i {B(j)} else {A(j, i, true)}, note that (j, i, true) == (j, false, true) - let entry_outs = [type_row![BOOL_T;3], type_row![BOOL_T]]; + let mut builder = + CFGBuilder::new(Signature::new(type_row![BOOL_T; 2], type_row![BOOL_T; 2])).unwrap(); + let false_c = builder.add_constant(Value::false_val()); + // entry (x, y) => if x {A(y, x=true)} else B(y)} + let entry_outs = [type_row![BOOL_T;2], type_row![BOOL_T]]; let mut entry = builder .entry_builder(entry_outs.clone(), type_row![]) .unwrap(); - let [in_i, in_j] = entry.input_wires_arr(); + let [in_x, in_y] = entry.input_wires_arr(); let mut cond = entry .conditional_builder( - (vec![type_row![]; 2], in_i), + (vec![type_row![]; 2], in_x), [], Type::new_sum(entry_outs.clone()).into(), ) .unwrap(); - let mut if_i_true = cond.case_builder(1).unwrap(); - let br_to_b = if_i_true - .add_dataflow_op(Tag::new(1, entry_outs.to_vec()), [in_j]) + let mut if_x_true = cond.case_builder(1).unwrap(); + let br_to_a = if_x_true + .add_dataflow_op(Tag::new(0, entry_outs.to_vec()), [in_y, in_x]) .unwrap(); - if_i_true.finish_with_outputs(br_to_b.outputs()).unwrap(); - let mut if_i_false = cond.case_builder(0).unwrap(); - let true_w = if_i_false.add_load_value(Value::true_val()); - let br_to_a = if_i_false - .add_dataflow_op(Tag::new(0, entry_outs.into()), [in_j, in_i, true_w]) + if_x_true.finish_with_outputs(br_to_a.outputs()).unwrap(); + let mut if_x_false = cond.case_builder(0).unwrap(); + let br_to_b = if_x_false + .add_dataflow_op(Tag::new(1, entry_outs.into()), [in_y]) .unwrap(); - if_i_false.finish_with_outputs(br_to_a.outputs()).unwrap(); + if_x_false.finish_with_outputs(br_to_b.outputs()).unwrap(); let [res] = cond.finish_sub_container().unwrap().outputs_arr(); let entry = entry.finish_with_outputs(res, []).unwrap(); - // A(w, y, z) => if w {B(y)} else {X(z)} - let a_outs = vec![type_row![BOOL_T]; 2]; + // A(y, z always true) => if y {X(false, z)} else {B(z)} + let a_outs = vec![type_row![BOOL_T], type_row![]]; let mut a = builder .block_builder( - type_row![BOOL_T; 3], - vec![type_row![BOOL_T]; 2], - type_row![], + type_row![BOOL_T; 2], + a_outs.clone(), + type_row![BOOL_T], // Trailing z common to both branches ) .unwrap(); - let [in_w, in_y, in_z] = a.input_wires_arr(); + let [in_y, in_z] = a.input_wires_arr(); + let mut cond = a .conditional_builder( - (vec![type_row![]; 2], in_w), + (vec![type_row![]; 2], in_y), [], Type::new_sum(a_outs.clone()).into(), ) .unwrap(); - let mut if_w_true = cond.case_builder(1).unwrap(); - let br_to_b = if_w_true - .add_dataflow_op(Tag::new(1, a_outs.clone()), [in_y]) + let mut if_y_true = cond.case_builder(1).unwrap(); + let false_w1 = if_y_true.load_const(&false_c); + let br_to_x = if_y_true + .add_dataflow_op(Tag::new(0, a_outs.clone()), [false_w1]) .unwrap(); - if_w_true.finish_with_outputs(br_to_b.outputs()).unwrap(); - let mut if_w_false = cond.case_builder(0).unwrap(); - let br_to_x = if_w_false - .add_dataflow_op(Tag::new(0, a_outs), [in_z]) - .unwrap(); - if_w_false.finish_with_outputs(br_to_x.outputs()).unwrap(); + if_y_true.finish_with_outputs(br_to_x.outputs()).unwrap(); + let mut if_y_false = cond.case_builder(0).unwrap(); + let br_to_b = if_y_false.add_dataflow_op(Tag::new(1, a_outs), []).unwrap(); + if_y_false.finish_with_outputs(br_to_b.outputs()).unwrap(); let [res] = cond.finish_sub_container().unwrap().outputs_arr(); - let a = a.finish_with_outputs(res, []).unwrap(); + let a = a.finish_with_outputs(res, [in_z]).unwrap(); - // B(v) => X(v) + // B(v) => X(v, false) let mut b = builder - .block_builder(type_row![BOOL_T], [type_row![BOOL_T]], type_row![]) + .block_builder(type_row![BOOL_T], [type_row![]], type_row![BOOL_T; 2]) .unwrap(); + let [in_v] = b.input_wires_arr(); + let false_w2 = b.load_const(&false_c); let [control] = b - .add_dataflow_op(Tag::new(0, vec![type_row![BOOL_T]]), b.input_wires()) + .add_dataflow_op(Tag::new(0, vec![type_row![]]), []) .unwrap() .outputs_arr(); - let b = b.finish_with_outputs(control, []).unwrap(); + let b = b.finish_with_outputs(control, [in_v, false_w2]).unwrap(); let x = builder.exit_block(); @@ -460,23 +462,25 @@ fn xnor_cfg() -> (Hugr, Node) { } #[rstest] -#[case(pv_true(), pv_true(), pv_true())] -#[case(pv_true(), pv_false(), pv_false())] -#[case(pv_true(), pv_true_or_false(), pv_true_or_false())] -#[case(pv_false(), pv_true(), pv_false())] -#[case(pv_false(), pv_false(), pv_true())] -#[case(pv_false(), pv_true_or_false(), pv_true_or_false())] -#[case(PartialValue::top(), pv_true(), PartialValue::top())] // Ideally, result should be true_or_false -#[case(PartialValue::top(), pv_false(), pv_true_or_false())] +#[should_panic] // first case failing +#[case(pv_true(), pv_true(), pv_false(), pv_true())] +#[case(pv_true(), pv_false(), pv_true(), pv_false())] +//#[case(pv_true(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_false(), pv_true(), pv_true(), pv_false())] +#[case(pv_false(), pv_false(), pv_false(), pv_false())] +/*#[case(pv_false(), pv_true_or_false(), pv_true_or_false())] +#[case(PartialValue::top(), pv_true(), pv_true_or_false())] +#[case(PartialValue::top(), pv_false(), PartialValue::top())] // Ideally pv_true_or_false #[case(pv_true_or_false(), pv_true(), pv_true_or_false())] -#[case(pv_true_or_false(), pv_false(), pv_true_or_false())] +#[case(pv_true_or_false(), pv_false(), pv_true_or_false())]*/ fn test_cfg( #[case] inp0: PartialValue, #[case] inp1: PartialValue, - #[case] outp: PartialValue, - xnor_cfg: (Hugr, Node), + #[case] out0: PartialValue, + #[case] out1: PartialValue, + xor_and_cfg: (Hugr, Node), ) { - let (hugr, entry_input) = xnor_cfg; + let (hugr, entry_input) = xor_and_cfg; let [in_w0, in_w1] = [0, 1].map(|i| Wire::new(entry_input, i)); @@ -486,6 +490,10 @@ fn test_cfg( assert_eq!( machine.read_out_wire(Wire::new(hugr.root(), 0)).unwrap(), - outp + out0 + ); + assert_eq!( + machine.read_out_wire(Wire::new(hugr.root(), 1)).unwrap(), + out1 ); } From da3c05c0594c56893153697842ceae4d33b1b4ef Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Oct 2024 16:00:18 +0100 Subject: [PATCH 128/281] BB reachability, fixes! --- hugr-passes/src/dataflow/datalog.rs | 13 ++++++++++++- hugr-passes/src/dataflow/machine.rs | 18 ++++++++++++++++++ hugr-passes/src/dataflow/test.rs | 1 - 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index d88e5cc98..f3c4868a8 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -150,6 +150,16 @@ ascent::ascent! { cfg_node(c,n) <-- node(c, n), if c.get_optype(*n).is_cfg(); dfb_block(c,cfg,blk) <-- cfg_node(c, cfg), for blk in c.children(*cfg), if c.get_optype(blk).is_dataflow_block(); + // Reachability + relation bb_reachable(C, Node, Node); + bb_reachable(c, cfg, entry) <-- cfg_node(c, cfg), if let Some(entry) = c.children(*cfg).next(); + bb_reachable(c, cfg, bb) <-- cfg_node(c, cfg), + bb_reachable(c, cfg, pred), + io_node(c, pred, pred_out, IO::Output), + in_wire_value(c, pred_out, IncomingPort::from(0), predicate), + for (tag, bb) in c.output_neighbours(*pred).enumerate(), + if predicate.supports_tag(tag); + // Where do the values "fed" along a control-flow edge come out? relation _cfg_succ_dest(C, Node, Node, Node); _cfg_succ_dest(c, cfg, blk, inp) <-- dfb_block(c, cfg, blk), io_node(c, blk, inp, IO::Input); @@ -162,9 +172,10 @@ ascent::ascent! { io_node(c, entry, i_node, IO::Input), in_wire_value(c, cfg, p, v); - // Outputs of each block propagated to successor blocks or (if exit block) then CFG itself + // Outputs of each reachable block propagated to successor block or (if exit block) then CFG itself out_wire_value(c, dest, OutgoingPort::from(out_p), v) <-- dfb_block(c, cfg, pred), + bb_reachable(c, cfg, pred), let df_block = c.get_optype(*pred).as_dataflow_block().unwrap(), for (succ_n, succ) in c.output_neighbours(*pred).enumerate(), io_node(c, pred, out_n, IO::Output), diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 15262d4db..44744b5ee 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -99,6 +99,24 @@ impl> Machine { .unwrap(), ) } + + /// Tells us if a block ([DataflowBlock] or [ExitBlock]) in a [CFG] is known + /// to be reachable. (Returns `None` if argument is not a child of a CFG.) + pub fn bb_reachable(&self, hugr: impl HugrView, bb: Node) -> Option { + let cfg = hugr.get_parent(bb)?; // Not really required...?? + hugr.get_optype(cfg).as_cfg()?; + let t = hugr.get_optype(bb); + if !t.is_dataflow_block() && !t.is_exit_block() { + return None; + }; + Some( + self.0 + .bb_reachable + .iter() + .find(|(_, cfg2, bb2)| *cfg2 == cfg && *bb2 == bb) + .is_some(), + ) + } } /// Tells whether a loop iterates (never, always, sometimes) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 9ac49e6d1..7777cd3f7 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -462,7 +462,6 @@ fn xor_and_cfg() -> (Hugr, Node) { } #[rstest] -#[should_panic] // first case failing #[case(pv_true(), pv_true(), pv_false(), pv_true())] #[case(pv_true(), pv_false(), pv_true(), pv_false())] //#[case(pv_true(), pv_true_or_false(), pv_true_or_false())] From 6930ad4358e3293320a62004a865694b39395ada Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Oct 2024 09:37:32 +0100 Subject: [PATCH 129/281] Test cases with true_or_false/top, standardize naming (->test_)conditional --- hugr-passes/src/dataflow/test.rs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 7777cd3f7..f3739062e 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -305,7 +305,7 @@ fn test_tail_loop_containing_conditional() { } #[test] -fn conditional() { +fn test_conditional() { let variants = vec![type_row![], type_row![], type_row![BOOL_T]]; let cond_t = Type::new_sum(variants.clone()); let mut builder = DFGBuilder::new(Signature::new(cond_t, type_row![])).unwrap(); @@ -464,14 +464,16 @@ fn xor_and_cfg() -> (Hugr, Node) { #[rstest] #[case(pv_true(), pv_true(), pv_false(), pv_true())] #[case(pv_true(), pv_false(), pv_true(), pv_false())] -//#[case(pv_true(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_true(), pv_true_or_false(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_true(), PartialValue::Top, pv_true_or_false(), pv_true_or_false())] #[case(pv_false(), pv_true(), pv_true(), pv_false())] #[case(pv_false(), pv_false(), pv_false(), pv_false())] -/*#[case(pv_false(), pv_true_or_false(), pv_true_or_false())] -#[case(PartialValue::top(), pv_true(), pv_true_or_false())] -#[case(PartialValue::top(), pv_false(), PartialValue::top())] // Ideally pv_true_or_false -#[case(pv_true_or_false(), pv_true(), pv_true_or_false())] -#[case(pv_true_or_false(), pv_false(), pv_true_or_false())]*/ +#[case(pv_false(), pv_true_or_false(), pv_true_or_false(), pv_false())] +#[case(pv_false(), PartialValue::Top, PartialValue::Top, pv_false())] // if !inp0 then out0=inp1 +#[case(pv_true_or_false(), pv_true(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_true_or_false(), pv_false(), pv_true_or_false(), pv_false())] +#[case(PartialValue::Top, pv_true(), pv_true_or_false(), PartialValue::Top)] +#[case(PartialValue::Top, pv_false(), PartialValue::Top, pv_false())] fn test_cfg( #[case] inp0: PartialValue, #[case] inp1: PartialValue, From 0a3e2812e8815108fdc153dabc10605598cfb36f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 11 Oct 2024 19:58:13 +0100 Subject: [PATCH 130/281] Try to common up by using case_reachable in conditional outputs - 6 tests fail --- hugr-passes/src/dataflow/datalog.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index f3c4868a8..68b09ea0e 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -133,9 +133,8 @@ ascent::ascent! { // outputs of case nodes propagate to outputs of conditional *if* case reachable out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- - case_node(c, cond, i, case), - in_wire_value(c, cond, IncomingPort::from(0), control), - if control.supports_tag(*i), + case_node(c, cond, _, case), + case_reachable(c, cond, case, true), io_node(c, case, o, IO::Output), in_wire_value(c, o, o_p, v); From b1e0bfd4b434a4d4891726528c76f7a338c3bdda Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 11 Oct 2024 20:01:09 +0100 Subject: [PATCH 131/281] Make case_reachable a relation (dropping bool), not lattice - fixes tests --- hugr-passes/src/dataflow/datalog.rs | 8 ++++---- hugr-passes/src/dataflow/machine.rs | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 68b09ea0e..eaf33ade1 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -134,14 +134,14 @@ ascent::ascent! { // outputs of case nodes propagate to outputs of conditional *if* case reachable out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- case_node(c, cond, _, case), - case_reachable(c, cond, case, true), + case_reachable(c, cond, case), io_node(c, case, o, IO::Output), in_wire_value(c, o, o_p, v); - lattice case_reachable(C, Node, Node, bool); - case_reachable(c, cond, case, reachable) <-- case_node(c,cond,i,case), + relation case_reachable(C, Node, Node); + case_reachable(c, cond, case) <-- case_node(c,cond,i,case), in_wire_value(c, cond, IncomingPort::from(0), v), - let reachable = v.supports_tag(*i); + if v.supports_tag(*i); // CFG relation cfg_node(C, Node); diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 44744b5ee..b0439c1b2 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -95,8 +95,8 @@ impl> Machine { self.0 .case_reachable .iter() - .find_map(|(_, cond2, case2, i)| (&cond == cond2 && &case == case2).then_some(*i)) - .unwrap(), + .find(|(_, cond2, case2)| &cond == cond2 && &case == case2) + .is_some(), ) } From 355e814f6262acc41f00bf852c1b803382b6883e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Oct 2024 09:59:44 +0100 Subject: [PATCH 132/281] Call (+test) --- hugr-passes/src/dataflow/datalog.rs | 18 +++++++++++++ hugr-passes/src/dataflow/test.rs | 42 ++++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index eaf33ade1..5c65d7bca 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -182,6 +182,23 @@ ascent::ascent! { node_in_value_row(c, out_n, out_in_row), if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), for (out_p, v) in fields.enumerate(); + + // Call + relation func_call(C, Node, Node); + func_call(c, call, func_defn) <-- + node(c, call), + if c.get_optype(*call).is_call(), + if let Some(func_defn) = c.static_source(*call); + + out_wire_value(c, inp, OutgoingPort::from(p.index()), v) <-- + func_call(c, call, func), + io_node(c, func, inp, IO::Input), + in_wire_value(c, call, p, v); + + out_wire_value(c, call, OutgoingPort::from(p.index()), v) <-- + func_call(c, call, func), + io_node(c, func, outp, IO::Output), + in_wire_value(c, outp, p, v); } fn propagate_leaf_op( @@ -208,6 +225,7 @@ fn propagate_leaf_op( ins.iter().cloned(), )])), OpType::Input(_) | OpType::Output(_) | OpType::ExitBlock(_) => None, // handled by parent + OpType::Call(_) => None, // handled via Input/Output of FuncDefn OpType::Const(_) => None, // handled by LoadConstant: OpType::LoadConstant(load_op) => { assert!(ins.is_empty()); // static edge, so need to find constant diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index f3739062e..4e314de49 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; -use hugr_core::builder::{CFGBuilder, Container}; +use hugr_core::builder::{CFGBuilder, Container, DataflowHugr}; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -498,3 +498,43 @@ fn test_cfg( out1 ); } + +#[rstest] +#[case(pv_true(), pv_true(), pv_true())] +#[case(pv_false(), pv_false(), pv_false())] +#[case(pv_true(), pv_false(), pv_true_or_false())] // Two calls alias +fn test_call( + #[case] inp0: PartialValue, + #[case] inp1: PartialValue, + #[case] out: PartialValue, +) { + let mut builder = DFGBuilder::new(Signature::new_endo(type_row![BOOL_T; 2])).unwrap(); + let func_bldr = builder + .define_function("id", Signature::new_endo(BOOL_T)) + .unwrap(); + let [v] = func_bldr.input_wires_arr(); + let func_defn = func_bldr.finish_with_outputs([v]).unwrap(); + let [a, b] = builder.input_wires_arr(); + let [a2] = builder + .call(func_defn.handle(), &[], [a], &EMPTY_REG) + .unwrap() + .outputs_arr(); + let [b2] = builder + .call(func_defn.handle(), &[], [b], &EMPTY_REG) + .unwrap() + .outputs_arr(); + let hugr = builder + .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) + .unwrap(); + + let [root_inp, _] = hugr.get_io(hugr.root()).unwrap(); + let [inw0, inw1] = [0, 1].map(|i| Wire::new(root_inp, i)); + let mut machine = Machine::default(); + machine.propolutate_out_wires([(inw0, inp0), (inw1, inp1)]); + machine.run(TestContext(Arc::new(&hugr))); + + let [res0, res1] = [0, 1].map(|i| machine.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); + // The two calls alias so both results will be the same: + assert_eq!(res0, out); + assert_eq!(res1, out); +} From 22f3ce8ecf4888a1ba3de31ae8a1ccbaa16a6ac9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Oct 2024 17:27:32 +0100 Subject: [PATCH 133/281] propolutate_out_wires => prepopulate and set in wires in run --- hugr-passes/src/dataflow/datalog.rs | 16 +++++--- hugr-passes/src/dataflow/machine.rs | 31 ++++++++++------ hugr-passes/src/dataflow/test.rs | 57 +++++++++++++---------------- 3 files changed, 56 insertions(+), 48 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 5c65d7bca..e9e1b92a5 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -30,7 +30,6 @@ pub enum IO { ascent::ascent! { pub(super) struct AscentProgram>; relation context(C); - relation out_wire_value_proto(Node, OutgoingPort, PV); relation node(C, Node); relation in_wire(C, Node, IncomingPort); @@ -38,8 +37,8 @@ ascent::ascent! { relation parent_of_node(C, Node, Node); relation io_node(C, Node, Node, IO); lattice out_wire_value(C, Node, OutgoingPort, PV); - lattice node_in_value_row(C, Node, ValueRow); lattice in_wire_value(C, Node, IncomingPort, PV); + lattice node_in_value_row(C, Node, ValueRow); node(c, n) <-- context(c), for n in c.nodes(); @@ -53,16 +52,21 @@ ascent::ascent! { io_node(c, parent, child, io) <-- node(c, parent), if let Some([i,o]) = c.get_io(*parent), for (child,io) in [(i,IO::Input),(o,IO::Output)]; - // We support prepopulating out_wire_value via out_wire_value_proto. - // - // out wires that do not have prepopulation values are initialised to bottom. + + // Initialize all wires to bottom out_wire_value(c, n,p, PV::bottom()) <-- out_wire(c, n,p); - out_wire_value(c, n, p, v) <-- out_wire(c,n,p) , out_wire_value_proto(n, p, v); in_wire_value(c, n, ip, v) <-- in_wire(c, n, ip), if let Some((m,op)) = c.single_linked_output(*n, *ip), out_wire_value(c, m, op, v); + // We support prepopulating in_wire_value via in_wire_value_proto. + relation in_wire_value_proto(Node, IncomingPort, PV); + in_wire_value(c, n, p, PV::bottom()) <-- in_wire(c, n,p); + in_wire_value(c, n, p, v) <-- node(c,n), + if let Some(sig) = c.signature(*n), + for p in sig.input_ports(), + in_wire_value_proto(n, p, v); node_in_value_row(c, n, ValueRow::new(sig.input_count())) <-- node(c, n), if let Some(sig) = c.signature(*n); node_in_value_row(c, n, ValueRow::single_known(c.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(c, n, p, v); diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index b0439c1b2..8f5a8d5a7 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use hugr_core::{HugrView, Node, PortIndex, Wire}; +use hugr_core::{HugrView, IncomingPort, Node, PortIndex, Wire}; use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; @@ -23,26 +23,35 @@ impl> Default for Machine { impl> Machine { /// Provide initial values for some wires. - /// (For example, if some properties of the Hugr's inputs are known.) - pub fn propolutate_out_wires( - &mut self, - wires: impl IntoIterator)>, - ) { + // Likely for test purposes only - should we make non-pub or #[cfg(test)] ? + pub fn prepopulate_wire(&mut self, h: &impl HugrView, wire: Wire, value: PartialValue) { assert!(self.1.is_none()); - self.0 - .out_wire_value_proto - .extend(wires.into_iter().map(|(w, v)| (w.node(), w.source(), v))); + self.0.in_wire_value_proto.extend( + h.linked_inputs(wire.node(), wire.source()) + .map(|(n, inp)| (n, inp, value.clone())), + ); } - /// Run the analysis (iterate until a lattice fixpoint is reached). + /// Run the analysis (iterate until a lattice fixpoint is reached), + /// given initial values for some of the root node inputs. + /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, + /// but should handle other containers.) /// The context passed in allows interpretation of leaf operations. /// /// # Panics /// /// If this Machine has been run already. /// - pub fn run(&mut self, context: C) { + pub fn run( + &mut self, + context: C, + in_values: impl IntoIterator)>, + ) { assert!(self.1.is_none()); + let root = context.root(); + self.0 + .in_wire_value_proto + .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); self.0.context.push((context,)); self.0.run(); self.1 = Some( diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 4e314de49..23ea73231 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -15,7 +15,7 @@ use hugr_core::{ types::{Signature, SumType, Type}, HugrView, }; -use hugr_core::{Hugr, Node, Wire}; +use hugr_core::{Hugr, Wire}; use rstest::{fixture, rstest}; use super::{AbstractValue, DFContext, Machine, PartialValue, TailLoopTermination}; @@ -98,7 +98,7 @@ fn test_make_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr))); + machine.run(TestContext(Arc::new(&hugr)), []); let x = machine .read_out_wire(v3) @@ -119,7 +119,7 @@ fn test_unpack_tuple_const() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr))); + machine.run(TestContext(Arc::new(&hugr)), []); let o1_r = machine .read_out_wire(o1) @@ -151,7 +151,7 @@ fn test_tail_loop_never_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr))); + machine.run(TestContext(Arc::new(&hugr)), []); let o_r = machine .read_out_wire(tl_o) @@ -185,7 +185,7 @@ fn test_tail_loop_always_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr))); + machine.run(TestContext(Arc::new(&hugr)), []); let o_r1 = machine.read_out_wire(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom()); @@ -224,7 +224,7 @@ fn test_tail_loop_two_iters() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr))); + machine.run(TestContext(Arc::new(&hugr)), []); let o_r1 = machine.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true_or_false()); @@ -291,7 +291,7 @@ fn test_tail_loop_containing_conditional() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr))); + machine.run(TestContext(Arc::new(&hugr)), []); let o_r1 = machine.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true()); @@ -344,8 +344,7 @@ fn test_conditional() { 2, [PartialValue::new_variant(0, [])], )); - machine.propolutate_out_wires([(arg_w, arg_pv)]); - machine.run(TestContext(Arc::new(&hugr))); + machine.run(TestContext(Arc::new(&hugr)), [(0.into(), arg_pv)]); let cond_r1 = machine .read_out_wire(cond_o1) @@ -365,13 +364,9 @@ fn test_conditional() { assert_eq!(machine.case_reachable(&hugr, cond.node()), None); } -// Tuple of -// 1. Hugr being a function on bools: (x, y) => (x XOR y, x AND y) -// 2. Input node of entry block -// Result readable from root node outputs -// Inputs should be placed onto out-wires of the Node (2.) +// A Hugr being a function on bools: (x, y) => (x XOR y, x AND y) #[fixture] -fn xor_and_cfg() -> (Hugr, Node) { +fn xor_and_cfg() -> Hugr { // Entry // /0 1\ // A --1-> B A(x=true, y) => if y then X(false, true) else B(x=true) @@ -456,9 +451,7 @@ fn xor_and_cfg() -> (Hugr, Node) { builder.branch(&a, 0, &x).unwrap(); builder.branch(&a, 1, &b).unwrap(); builder.branch(&b, 0, &x).unwrap(); - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let [entry_input, _] = hugr.get_io(entry.node()).unwrap(); - (hugr, entry_input) + builder.finish_hugr(&EMPTY_REG).unwrap() } #[rstest] @@ -479,22 +472,24 @@ fn test_cfg( #[case] inp1: PartialValue, #[case] out0: PartialValue, #[case] out1: PartialValue, - xor_and_cfg: (Hugr, Node), + xor_and_cfg: Hugr, ) { - let (hugr, entry_input) = xor_and_cfg; - - let [in_w0, in_w1] = [0, 1].map(|i| Wire::new(entry_input, i)); - let mut machine = Machine::default(); - machine.propolutate_out_wires([(in_w0, inp0), (in_w1, inp1)]); - machine.run(TestContext(Arc::new(&hugr))); + machine.run( + TestContext(Arc::new(&xor_and_cfg)), + [(0.into(), inp0), (1.into(), inp1)], + ); assert_eq!( - machine.read_out_wire(Wire::new(hugr.root(), 0)).unwrap(), + machine + .read_out_wire(Wire::new(xor_and_cfg.root(), 0)) + .unwrap(), out0 ); assert_eq!( - machine.read_out_wire(Wire::new(hugr.root(), 1)).unwrap(), + machine + .read_out_wire(Wire::new(xor_and_cfg.root(), 1)) + .unwrap(), out1 ); } @@ -527,11 +522,11 @@ fn test_call( .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) .unwrap(); - let [root_inp, _] = hugr.get_io(hugr.root()).unwrap(); - let [inw0, inw1] = [0, 1].map(|i| Wire::new(root_inp, i)); let mut machine = Machine::default(); - machine.propolutate_out_wires([(inw0, inp0), (inw1, inp1)]); - machine.run(TestContext(Arc::new(&hugr))); + machine.run( + TestContext(Arc::new(&hugr)), + [(0.into(), inp0), (1.into(), inp1)], + ); let [res0, res1] = [0, 1].map(|i| machine.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); // The two calls alias so both results will be the same: From 68b1d486efeb18f24515a23e043dd83de17ffe9e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Oct 2024 17:36:01 +0100 Subject: [PATCH 134/281] Rm/inline value_inputs/value_outputs, use UnpackTuple, comments --- hugr-passes/src/dataflow/datalog.rs | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index e9e1b92a5..f664a84d4 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -42,9 +42,8 @@ ascent::ascent! { node(c, n) <-- context(c), for n in c.nodes(); - in_wire(c, n,p) <-- node(c, n), for p in value_inputs(c.as_ref(), *n); - - out_wire(c, n,p) <-- node(c, n), for p in value_outputs(c.as_ref(), *n); + in_wire(c, n,p) <-- node(c, n), for (p,_) in c.in_value_types(*n); // Note, gets connected inports only + out_wire(c, n,p) <-- node(c, n), for (p,_) in c.out_value_types(*n); // (and likewise) parent_of_node(c, parent, child) <-- node(c, child), if let Some(parent) = c.get_parent(*child); @@ -220,8 +219,9 @@ fn propagate_leaf_op( ins.iter().cloned(), )])), op if op.cast::().is_some() => { + let elem_tys = op.cast::().unwrap().0; let [tup] = ins.iter().collect::>().try_into().unwrap(); - tup.variant_values(0, value_outputs(c.as_ref(), n).count()) + tup.variant_values(0, elem_tys.len()) .map(ValueRow::from_iter) } OpType::Tag(t) => Some(ValueRow::from_iter([PV::new_variant( @@ -257,16 +257,7 @@ fn propagate_leaf_op( } } -fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { - h.in_value_types(n).map(|x| x.0) -} - -fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { - h.out_value_types(n).map(|x| x.0) -} - -// Wrap a (known-length) row of values into a lattice. Perhaps could be part of partial_value.rs? - +// Wrap a (known-length) row of values into a lattice. #[derive(PartialEq, Clone, Eq, Hash)] struct ValueRow(Vec>); From 2cc62f0d17999b4d84ae5b042bd932f592a43a00 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Oct 2024 17:48:28 +0100 Subject: [PATCH 135/281] clippy --- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/dataflow/machine.rs | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 370b4d643..a861d48b7 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -81,7 +81,7 @@ fn traverse_value( .map(PartialValue::from) .unwrap_or(PartialValue::Top), Value::Function { hugr } => s - .value_from_const_hugr(n, fields, &**hugr) + .value_from_const_hugr(n, fields, hugr) .map(PartialValue::from) .unwrap_or(PartialValue::Top), } diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 8f5a8d5a7..50d7b088b 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -104,8 +104,7 @@ impl> Machine { self.0 .case_reachable .iter() - .find(|(_, cond2, case2)| &cond == cond2 && &case == case2) - .is_some(), + .any(|(_, cond2, case2)| &cond == cond2 && &case == case2), ) } @@ -122,8 +121,7 @@ impl> Machine { self.0 .bb_reachable .iter() - .find(|(_, cfg2, bb2)| *cfg2 == cfg && *bb2 == bb) - .is_some(), + .any(|(_, cfg2, bb2)| *cfg2 == cfg && *bb2 == bb), ) } } From 8254771b22d626c0657722b01d8ab9ba62cce900 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Oct 2024 17:52:01 +0100 Subject: [PATCH 136/281] docs --- hugr-passes/src/dataflow/machine.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 50d7b088b..888e34505 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -6,8 +6,8 @@ use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] -/// 2. Zero or more [Self::propolutate_out_wires] with initial values -/// 3. Exactly one [Self::run] to do the analysis +/// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values +/// 3. Exactly one [Self::run], with initial values for root inputs, to do the analysis /// 4. Results then available via [Self::read_out_wire] pub struct Machine>( AscentProgram, @@ -110,6 +110,10 @@ impl> Machine { /// Tells us if a block ([DataflowBlock] or [ExitBlock]) in a [CFG] is known /// to be reachable. (Returns `None` if argument is not a child of a CFG.) + /// + /// [CFG]: hugr_core::ops::CFG + /// [DataflowBlock]: hugr_core::ops::DataflowBlock + /// [ExitBlock]: hugr_core::ops::ExitBlock pub fn bb_reachable(&self, hugr: impl HugrView, bb: Node) -> Option { let cfg = hugr.get_parent(bb)?; // Not really required...?? hugr.get_optype(cfg).as_cfg()?; From 7e81b153102f87be1c7ba44d7d008bc884be013c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 21 Oct 2024 09:38:20 +0100 Subject: [PATCH 137/281] Separate AnalysisResults from Machine, use context.exactly_one() not HugrView --- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/dataflow/machine.rs | 65 ++++++++++---------- hugr-passes/src/dataflow/test.rs | 92 ++++++++++++----------------- 3 files changed, 72 insertions(+), 87 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index a861d48b7..e04b4c2a8 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -5,7 +5,7 @@ mod datalog; mod machine; use hugr_core::ops::constant::CustomConst; -pub use machine::{Machine, TailLoopTermination}; +pub use machine::{AnalysisResults, Machine, TailLoopTermination}; mod partial_value; pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 888e34505..4ac66686c 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -1,23 +1,27 @@ use std::collections::HashMap; use hugr_core::{HugrView, IncomingPort, Node, PortIndex, Wire}; +use itertools::Itertools; use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] /// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values -/// 3. Exactly one [Self::run], with initial values for root inputs, to do the analysis -/// 4. Results then available via [Self::read_out_wire] -pub struct Machine>( - AscentProgram, - Option>>, +/// 3. Call [Self::run] to produce [AnalysisResults] which can be inspected via +/// [read_out_wire](AnalysisResults::read_out_wire) +pub struct Machine>(AscentProgram); + +/// Results of a dataflow analysis. +pub struct AnalysisResults>( + AscentProgram, // Already run - kept for tests/debug + HashMap>, ); /// derived-Default requires the context to be Defaultable, which is unnecessary impl> Default for Machine { fn default() -> Self { - Self(Default::default(), None) + Self(Default::default()) } } @@ -25,7 +29,6 @@ impl> Machine { /// Provide initial values for some wires. // Likely for test purposes only - should we make non-pub or #[cfg(test)] ? pub fn prepopulate_wire(&mut self, h: &impl HugrView, wire: Wire, value: PartialValue) { - assert!(self.1.is_none()); self.0.in_wire_value_proto.extend( h.linked_inputs(wire.node(), wire.source()) .map(|(n, inp)| (n, inp, value.clone())), @@ -37,35 +40,36 @@ impl> Machine { /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, /// but should handle other containers.) /// The context passed in allows interpretation of leaf operations. - /// - /// # Panics - /// - /// If this Machine has been run already. - /// pub fn run( - &mut self, + mut self, context: C, in_values: impl IntoIterator)>, - ) { - assert!(self.1.is_none()); + ) -> AnalysisResults { let root = context.root(); self.0 .in_wire_value_proto .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); self.0.context.push((context,)); self.0.run(); - self.1 = Some( - self.0 - .out_wire_value - .iter() - .map(|(_, n, p, v)| (Wire::new(*n, *p), v.clone())) - .collect(), - ) + let results = self + .0 + .out_wire_value + .iter() + .map(|(_, n, p, v)| (Wire::new(*n, *p), v.clone())) + .collect(); + AnalysisResults(self.0, results) + } +} + +impl> AnalysisResults { + fn context(&self) -> &C { + let (c,) = self.0.context.iter().exactly_one().ok().unwrap(); + c } - /// Gets the lattice value computed by [Self::run] for the given wire + /// Gets the lattice value computed for the given wire pub fn read_out_wire(&self, w: Wire) -> Option> { - self.1.as_ref().unwrap().get(&w).cloned() + self.1.get(&w).cloned() } /// Tells whether a [TailLoop] node can terminate, i.e. whether @@ -73,11 +77,8 @@ impl> Machine { /// Returns `None` if the specified `node` is not a [TailLoop]. /// /// [TailLoop]: hugr_core::ops::TailLoop - pub fn tail_loop_terminates( - &self, - hugr: impl HugrView, - node: Node, - ) -> Option { + pub fn tail_loop_terminates(&self, node: Node) -> Option { + let hugr = self.context(); hugr.get_optype(node).as_tail_loop()?; let [_, out] = hugr.get_io(node).unwrap(); Some(TailLoopTermination::from_control_value( @@ -96,7 +97,8 @@ impl> Machine { /// /// [Case]: hugr_core::ops::Case /// [Conditional]: hugr_core::ops::Conditional - pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> Option { + pub fn case_reachable(&self, case: Node) -> Option { + let hugr = self.context(); hugr.get_optype(case).as_case()?; let cond = hugr.get_parent(case)?; hugr.get_optype(cond).as_conditional()?; @@ -114,7 +116,8 @@ impl> Machine { /// [CFG]: hugr_core::ops::CFG /// [DataflowBlock]: hugr_core::ops::DataflowBlock /// [ExitBlock]: hugr_core::ops::ExitBlock - pub fn bb_reachable(&self, hugr: impl HugrView, bb: Node) -> Option { + pub fn bb_reachable(&self, bb: Node) -> Option { + let hugr = self.context(); let cfg = hugr.get_parent(bb)?; // Not really required...?? hugr.get_optype(cfg).as_cfg()?; let t = hugr.get_optype(bb); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 23ea73231..e79a0024b 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -97,10 +97,9 @@ fn test_make_tuple() { let v3 = builder.make_tuple([v1, v2]).unwrap(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let x = machine + let x = results .read_out_wire(v3) .unwrap() .try_into_wire_value(&hugr, v3) @@ -118,16 +117,15 @@ fn test_unpack_tuple_const() { .outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let o1_r = machine + let o1_r = results .read_out_wire(o1) .unwrap() .try_into_wire_value(&hugr, o1) .unwrap(); assert_eq!(o1_r, Value::false_val()); - let o2_r = machine + let o2_r = results .read_out_wire(o2) .unwrap() .try_into_wire_value(&hugr, o2) @@ -150,10 +148,9 @@ fn test_tail_loop_never_iterates() { let [tl_o] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let o_r = machine + let o_r = results .read_out_wire(tl_o) .unwrap() .try_into_wire_value(&hugr, tl_o) @@ -161,7 +158,7 @@ fn test_tail_loop_never_iterates() { assert_eq!(o_r, r_v); assert_eq!( Some(TailLoopTermination::NeverContinues), - machine.tail_loop_terminates(&hugr, tail_loop.node()) + results.tail_loop_terminates(tail_loop.node()) ) } @@ -184,18 +181,17 @@ fn test_tail_loop_always_iterates() { let [tl_o1, tl_o2] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let o_r1 = machine.read_out_wire(tl_o1).unwrap(); + let o_r1 = results.read_out_wire(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom()); - let o_r2 = machine.read_out_wire(tl_o2).unwrap(); + let o_r2 = results.read_out_wire(tl_o2).unwrap(); assert_eq!(o_r2, PartialValue::bottom()); assert_eq!( Some(TailLoopTermination::NeverBreaks), - machine.tail_loop_terminates(&hugr, tail_loop.node()) + results.tail_loop_terminates(tail_loop.node()) ); - assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); + assert_eq!(results.tail_loop_terminates(hugr.root()), None); } #[test] @@ -223,18 +219,17 @@ fn test_tail_loop_two_iters() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let o_r1 = machine.read_out_wire(o_w1).unwrap(); + let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true_or_false()); - let o_r2 = machine.read_out_wire(o_w2).unwrap(); + let o_r2 = results.read_out_wire(o_w2).unwrap(); assert_eq!(o_r2, pv_true_or_false()); assert_eq!( Some(TailLoopTermination::BreaksAndContinues), - machine.tail_loop_terminates(&hugr, tail_loop.node()) + results.tail_loop_terminates(tail_loop.node()) ); - assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); + assert_eq!(results.tail_loop_terminates(hugr.root()), None); } #[test] @@ -290,18 +285,17 @@ fn test_tail_loop_containing_conditional() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let o_r1 = machine.read_out_wire(o_w1).unwrap(); + let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true()); - let o_r2 = machine.read_out_wire(o_w2).unwrap(); + let o_r2 = results.read_out_wire(o_w2).unwrap(); assert_eq!(o_r2, pv_false()); assert_eq!( Some(TailLoopTermination::BreaksAndContinues), - machine.tail_loop_terminates(&hugr, tail_loop.node()) + results.tail_loop_terminates(tail_loop.node()) ); - assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); + assert_eq!(results.tail_loop_terminates(hugr.root()), None); } #[test] @@ -339,29 +333,28 @@ fn test_conditional() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::default(); let arg_pv = PartialValue::new_variant(1, []).join(PartialValue::new_variant( 2, [PartialValue::new_variant(0, [])], )); - machine.run(TestContext(Arc::new(&hugr)), [(0.into(), arg_pv)]); + let results = Machine::default().run(TestContext(Arc::new(&hugr)), [(0.into(), arg_pv)]); - let cond_r1 = machine + let cond_r1 = results .read_out_wire(cond_o1) .unwrap() .try_into_wire_value(&hugr, cond_o1) .unwrap(); assert_eq!(cond_r1, Value::false_val()); - assert!(machine + assert!(results .read_out_wire(cond_o2) .unwrap() .try_into_wire_value(&hugr, cond_o2) .is_err()); - assert_eq!(machine.case_reachable(&hugr, case1.node()), Some(false)); // arg_pv is variant 1 or 2 only - assert_eq!(machine.case_reachable(&hugr, case2.node()), Some(true)); - assert_eq!(machine.case_reachable(&hugr, case3.node()), Some(true)); - assert_eq!(machine.case_reachable(&hugr, cond.node()), None); + assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only + assert_eq!(results.case_reachable(case2.node()), Some(true)); + assert_eq!(results.case_reachable(case3.node()), Some(true)); + assert_eq!(results.case_reachable(cond.node()), None); } // A Hugr being a function on bools: (x, y) => (x XOR y, x AND y) @@ -474,24 +467,14 @@ fn test_cfg( #[case] out1: PartialValue, xor_and_cfg: Hugr, ) { - let mut machine = Machine::default(); - machine.run( - TestContext(Arc::new(&xor_and_cfg)), + let root = xor_and_cfg.root(); + let results = Machine::default().run( + TestContext(Arc::new(xor_and_cfg)), [(0.into(), inp0), (1.into(), inp1)], ); - assert_eq!( - machine - .read_out_wire(Wire::new(xor_and_cfg.root(), 0)) - .unwrap(), - out0 - ); - assert_eq!( - machine - .read_out_wire(Wire::new(xor_and_cfg.root(), 1)) - .unwrap(), - out1 - ); + assert_eq!(results.read_out_wire(Wire::new(root, 0)).unwrap(), out0); + assert_eq!(results.read_out_wire(Wire::new(root, 1)).unwrap(), out1); } #[rstest] @@ -522,13 +505,12 @@ fn test_call( .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) .unwrap(); - let mut machine = Machine::default(); - machine.run( + let results = Machine::default().run( TestContext(Arc::new(&hugr)), [(0.into(), inp0), (1.into(), inp1)], ); - let [res0, res1] = [0, 1].map(|i| machine.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); + let [res0, res1] = [0, 1].map(|i| results.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); // The two calls alias so both results will be the same: assert_eq!(res0, out); assert_eq!(res1, out); From 34e82ede05e00d64a5e4cf8a5de7a310d72987d3 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 21 Oct 2024 09:51:53 +0100 Subject: [PATCH 138/281] Move try_into_wire_value => AnalysisResults.try_read_wire_value --- hugr-passes/src/dataflow/machine.rs | 28 +++++++++++++++++- hugr-passes/src/dataflow/partial_value.rs | 29 ------------------ hugr-passes/src/dataflow/test.rs | 36 ++++------------------- 3 files changed, 33 insertions(+), 60 deletions(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 4ac66686c..df2320ff1 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use hugr_core::{HugrView, IncomingPort, Node, PortIndex, Wire}; +use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; use itertools::Itertools; use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; @@ -133,6 +133,32 @@ impl> AnalysisResults { } } +impl> AnalysisResults +where + Value: From, +{ + /// Reads a [Value] from an output wire, if the lattice value computed for it can be turned + /// into one. (The lattice value must be either a single [Value](Self::Value) or + /// a [Sum](PartialValue::PartialSum with a single known tag.) + /// + /// # Errors + /// `None` if the analysis did not result in a single value on that wire + /// `Some(e)` if conversion to a [Value] produced a [ConstTypeError] + /// + /// # Panics + /// + /// If a [Type] for the specified wire could not be extracted from the Hugr + pub fn try_read_wire_value(&self, w: Wire) -> Result> { + let v = self.read_out_wire(w).ok_or(None)?; + let (_, typ) = self + .context() + .out_value_types(w.node()) + .find(|(p, _)| *p == w.source()) + .unwrap(); + v.try_into_value(&typ) + } +} + /// Tells whether a loop iterates (never, always, sometimes) #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] pub enum TailLoopTermination { diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index d3010ae45..5b5695dd0 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -2,7 +2,6 @@ use ascent::lattice::BoundedLattice; use ascent::Lattice; use hugr_core::ops::Value; use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; -use hugr_core::{HugrView, Wire}; use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::collections::HashMap; @@ -327,34 +326,6 @@ impl TryFrom> for Value { } } -impl PartialValue -where - Value: From, -{ - /// Turns this instance into a [Value], if it is either a single [Value](Self::Value) or - /// a [Sum](PartialValue::PartialSum) with a single known tag, extracting the desired type - /// from a HugrView and Wire. - /// - /// # Errors - /// `None` if the analysis did not result in a single value on that wire - /// `Some(e)` if conversion to a [Value] produced a [ConstTypeError] - /// - /// # Panics - /// - /// If a [Type] for the specified wire could not be extracted from the Hugr - pub fn try_into_wire_value( - self, - hugr: &impl HugrView, - w: Wire, - ) -> Result> { - let (_, typ) = hugr - .out_value_types(w.node()) - .find(|(p, _)| *p == w.source()) - .unwrap(); - self.try_into_value(&typ) - } -} - impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index e79a0024b..74d504b74 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -99,11 +99,7 @@ fn test_make_tuple() { let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let x = results - .read_out_wire(v3) - .unwrap() - .try_into_wire_value(&hugr, v3) - .unwrap(); + let x = results.try_read_wire_value(v3).unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); } @@ -119,17 +115,9 @@ fn test_unpack_tuple_const() { let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let o1_r = results - .read_out_wire(o1) - .unwrap() - .try_into_wire_value(&hugr, o1) - .unwrap(); + let o1_r = results.try_read_wire_value(o1).unwrap(); assert_eq!(o1_r, Value::false_val()); - let o2_r = results - .read_out_wire(o2) - .unwrap() - .try_into_wire_value(&hugr, o2) - .unwrap(); + let o2_r = results.try_read_wire_value(o2).unwrap(); assert_eq!(o2_r, Value::true_val()); } @@ -150,11 +138,7 @@ fn test_tail_loop_never_iterates() { let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let o_r = results - .read_out_wire(tl_o) - .unwrap() - .try_into_wire_value(&hugr, tl_o) - .unwrap(); + let o_r = results.try_read_wire_value(tl_o).unwrap(); assert_eq!(o_r, r_v); assert_eq!( Some(TailLoopTermination::NeverContinues), @@ -339,17 +323,9 @@ fn test_conditional() { )); let results = Machine::default().run(TestContext(Arc::new(&hugr)), [(0.into(), arg_pv)]); - let cond_r1 = results - .read_out_wire(cond_o1) - .unwrap() - .try_into_wire_value(&hugr, cond_o1) - .unwrap(); + let cond_r1 = results.try_read_wire_value(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); - assert!(results - .read_out_wire(cond_o2) - .unwrap() - .try_into_wire_value(&hugr, cond_o2) - .is_err()); + assert!(results.try_read_wire_value(cond_o2).is_err()); assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only assert_eq!(results.case_reachable(case2.node()), Some(true)); From 015707f98a21d99cb0729e4be776604d2b16b4cf Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 21 Oct 2024 10:16:14 +0100 Subject: [PATCH 139/281] doc fixes and fix comment --- hugr-passes/src/dataflow/machine.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index df2320ff1..4ddde626e 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -14,7 +14,7 @@ pub struct Machine>(AscentProgram); /// Results of a dataflow analysis. pub struct AnalysisResults>( - AscentProgram, // Already run - kept for tests/debug + AscentProgram, // Already run HashMap>, ); @@ -138,8 +138,8 @@ where Value: From, { /// Reads a [Value] from an output wire, if the lattice value computed for it can be turned - /// into one. (The lattice value must be either a single [Value](Self::Value) or - /// a [Sum](PartialValue::PartialSum with a single known tag.) + /// into one. (The lattice value must be either a single [Value](PartialValue::Value) or + /// a [Sum](PartialValue::PartialSum) with a single known tag.) /// /// # Errors /// `None` if the analysis did not result in a single value on that wire @@ -147,7 +147,7 @@ where /// /// # Panics /// - /// If a [Type] for the specified wire could not be extracted from the Hugr + /// If a [Type](hugr_core::types::Type) for the specified wire could not be extracted from the Hugr pub fn try_read_wire_value(&self, w: Wire) -> Result> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self From ff39f7d92831ddcc9babf6bf947414b1851ae15a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 21 Oct 2024 10:21:22 +0100 Subject: [PATCH 140/281] Try to make clippy happy --- hugr-passes/src/dataflow/machine.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 4ddde626e..ccd13e517 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -9,7 +9,7 @@ use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; /// 1. Get a new instance via [Self::default()] /// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values /// 3. Call [Self::run] to produce [AnalysisResults] which can be inspected via -/// [read_out_wire](AnalysisResults::read_out_wire) +/// [read_out_wire](AnalysisResults::read_out_wire) pub struct Machine>(AscentProgram); /// Results of a dataflow analysis. From 502d4a2d0d61c3e5eb84ee567bc2969cdef560d7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 21 Oct 2024 18:59:16 +0100 Subject: [PATCH 141/281] Fix/make-compile total_context.rs --- hugr-passes/src/dataflow/total_context.rs | 49 +++++++++++++---------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index dc3c7a69a..9dcfa8a80 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -1,18 +1,24 @@ use std::hash::Hash; +use hugr_core::ops::ExtensionOp; use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; -use super::partial_value::{AbstractValue, PartialValue, ValueOrSum}; +use super::partial_value::{AbstractValue, PartialValue, Sum}; use super::DFContext; /// A simpler interface like [DFContext] but where the context only cares about -/// values that are completely known (in the lattice `V`) -/// rather than e.g. Sums potentially of two variants each of known values. +/// values that are completely known (in the lattice `V`) rather than partially +/// (e.g. no [PartialSum]s of more than one variant, no top/bottom) pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { - type InterpretableVal: TryFrom>; + /// Representation of a (single, non-partial) value usable for interpretation + type InterpretableVal: From + TryFrom>; + + /// Interpret an (extension) operation given total values for some of the in-ports + /// `ins` will be a non-empty slice with distinct [IncomingPort]s. fn interpret_leaf_op( &self, node: Node, + e: &ExtensionOp, ins: &[(IncomingPort, Self::InterpretableVal)], ) -> Vec<(OutgoingPort, V)>; } @@ -21,32 +27,33 @@ impl> DFContext for T { fn interpret_leaf_op( &self, node: Node, + e: &ExtensionOp, ins: &[PartialValue], - ) -> Option>> { + outs: &mut [PartialValue], + ) { let op = self.get_optype(node); - let sig = op.dataflow_signature()?; + let Some(sig) = op.dataflow_signature() else { + return; + }; let known_ins = sig .input_types() .iter() .enumerate() .zip(ins.iter()) .filter_map(|((i, ty), pv)| { - pv.clone() - .try_into_value(ty) - // Discard PVs which don't produce ValueOrSum, i.e. Bottom/Top :-) - .ok() - // And discard any ValueOrSum that don't produce V - this is a bit silent :-( - .and_then(|v_s| T::InterpretableVal::try_from(v_s).ok()) - .map(|v| (IncomingPort::from(i), v)) + let v = match pv { + PartialValue::Bottom | PartialValue::Top => None, + PartialValue::Value(v) => Some(v.clone().into()), + PartialValue::PartialSum(ps) => T::InterpretableVal::try_from( + ps.clone().try_into_value::(ty).ok()?, + ) + .ok(), + }?; + Some((IncomingPort::from(i), v)) }) .collect::>(); - let known_outs = self.interpret_leaf_op(node, &known_ins); - (!known_outs.is_empty()).then(|| { - let mut res = vec![PartialValue::Bottom; sig.output_count()]; - for (p, v) in known_outs { - res[p.index()] = v.into(); - } - res - }) + for (p, v) in self.interpret_leaf_op(node, e, &known_ins) { + outs[p.index()] = PartialValue::Value(v); + } } } From 25c4a825f5fcbb282871538e2028e3307e631d8a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 21 Oct 2024 19:49:26 +0100 Subject: [PATCH 142/281] total_context: return PartialValue, as need some repr of Sum --- hugr-passes/src/dataflow/total_context.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index 9dcfa8a80..986e3ff50 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -20,7 +20,7 @@ pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { node: Node, e: &ExtensionOp, ins: &[(IncomingPort, Self::InterpretableVal)], - ) -> Vec<(OutgoingPort, V)>; + ) -> Vec<(OutgoingPort, PartialValue)>; } impl> DFContext for T { @@ -53,7 +53,7 @@ impl> DFContext for T { }) .collect::>(); for (p, v) in self.interpret_leaf_op(node, e, &known_ins) { - outs[p.index()] = PartialValue::Value(v); + outs[p.index()] = v; } } } From e7f6ad1686e6050132c2985628e920efc50ae653 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 21 Oct 2024 19:49:50 +0100 Subject: [PATCH 143/281] const_fold2: disable missing-docs warning --- hugr-passes/src/const_fold2.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index 58f285d43..93b772d88 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -1,4 +1,3 @@ -#![warn(missing_docs)] //! An (example) use of the [super::dataflow](dataflow-analysis framework) //! to perform constant-folding. From 8a8db968f453563af40087bcfccb9779fa429dc8 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 21 Oct 2024 19:57:23 +0100 Subject: [PATCH 144/281] fix context, value_handle --- hugr-passes/src/const_fold2/context.rs | 37 +++------ hugr-passes/src/const_fold2/value_handle.rs | 85 +++++++++++---------- 2 files changed, 56 insertions(+), 66 deletions(-) diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs index 6338629df..886c8da71 100644 --- a/hugr-passes/src/const_fold2/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -2,11 +2,11 @@ use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; -use hugr_core::ops::{OpType, Value}; +use hugr_core::ops::{ExtensionOp, Value}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; use super::value_handle::{ValueHandle, ValueKey}; -use crate::dataflow::TotalContext; +use crate::dataflow::{PartialValue, TotalContext}; /// A [context](crate::dataflow::DFContext) for doing analysis with [ValueHandle]s. /// Interprets [LoadConstant](OpType::LoadConstant) nodes, @@ -71,31 +71,14 @@ impl TotalContext for HugrValueContext { fn interpret_leaf_op( &self, n: Node, + op: &ExtensionOp, ins: &[(IncomingPort, Value)], - ) -> Vec<(OutgoingPort, ValueHandle)> { - match self.0.get_optype(n) { - OpType::LoadConstant(load_op) => { - assert!(ins.is_empty()); // static edge, so need to find constant - let const_node = self - .0 - .single_linked_output(n, load_op.constant_port()) - .unwrap() - .0; - let const_op = self.0.get_optype(const_node).as_const().unwrap(); - vec![( - OutgoingPort::from(0), - ValueHandle::new(const_node.into(), Arc::new(const_op.value().clone())), - )] - } - OpType::ExtensionOp(op) => { - let ins = ins.iter().map(|(p, v)| (*p, v.clone())).collect::>(); - op.constant_fold(&ins).map_or(Vec::new(), |outs| { - outs.into_iter() - .map(|(p, v)| (p, ValueHandle::new(ValueKey::Node(n), Arc::new(v)))) - .collect() - }) - } - _ => vec![], - } + ) -> Vec<(OutgoingPort, PartialValue)> { + let ins = ins.iter().map(|(p, v)| (*p, v.clone())).collect::>(); + op.constant_fold(&ins).map_or(Vec::new(), |outs| { + outs.into_iter() + .map(|(p, v)| (p, ValueHandle::new(ValueKey::Node(n), v))) + .collect() + }) } } diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index bbcd25129..05d57bfaf 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -2,12 +2,13 @@ use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1 use std::hash::{Hash, Hasher}; use std::sync::Arc; -use hugr_core::ops::constant::{CustomConst, Sum}; +use hugr_core::ops::constant::{CustomConst, OpaqueValue}; use hugr_core::ops::Value; use hugr_core::types::Type; -use hugr_core::Node; +use hugr_core::{Hugr, HugrView, Node}; +use itertools::Either; -use crate::dataflow::AbstractValue; +use crate::dataflow::{AbstractValue, PartialValue}; #[derive(Clone, Debug)] pub struct HashedConst { @@ -69,37 +70,35 @@ impl ValueKey { } #[derive(Clone, Debug)] -pub struct ValueHandle(ValueKey, Arc); +pub struct ValueHandle(ValueKey, Arc>>); impl ValueHandle { - pub fn new(key: ValueKey, value: Arc) -> Self { - Self(key, value) - } - - pub fn value(&self) -> &Value { - self.1.as_ref() + pub fn new(key: ValueKey, value: Value) -> PartialValue { + match value { + Value::Extension { e } => PartialValue::Value(Self(key, Arc::new(Either::Left(e)))), + Value::Function { hugr } => { + PartialValue::Value(Self(key, Arc::new(Either::Right(hugr)))) + } + Value::Sum(sum) => PartialValue::new_variant( + sum.tag, + sum.values + .into_iter() + .enumerate() + .map(|(i, v)| Self::new(key.clone().field(i), v)), + ), + } } pub fn get_type(&self) -> Type { - self.1.get_type() - } -} - -impl AbstractValue for ValueHandle { - fn as_sum(&self) -> Option<(usize, impl Iterator + '_)> { - match self.value() { - Value::Sum(Sum { tag, values, .. }) => Some(( - *tag, - values - .iter() - .enumerate() - .map(|(i, v)| Self(self.0.clone().field(i), Arc::new(v.clone()))), - )), - _ => None, + match &*self.1 { + Either::Left(e) => e.get_type(), + Either::Right(bh) => Type::new_function(bh.inner_function_type().unwrap()), } } } +impl AbstractValue for ValueHandle {} + impl PartialEq for ValueHandle { fn eq(&self, other: &Self) -> bool { // If the keys are equal, we return true since the values must have the @@ -125,7 +124,10 @@ impl Hash for ValueHandle { impl From for Value { fn from(value: ValueHandle) -> Self { - (*value.1).clone() + match Arc::>::unwrap_or_clone(value.1) { + Either::Left(e) => Value::Extension { e }, + Either::Right(hugr) => Value::Function { hugr }, + } } } @@ -143,6 +145,7 @@ mod test { }, types::SumType, }; + use itertools::Itertools; use super::*; @@ -201,22 +204,26 @@ mod test { #[test] fn value_handle_eq() { let k_i = ConstInt::new_u(4, 2).unwrap(); - let subject_val = Arc::new( - Value::sum( - 0, - [k_i.clone().into()], - SumType::new([vec![k_i.get_type()], vec![]]), - ) - .unwrap(), - ); + let st = SumType::new([vec![k_i.get_type()], vec![]]); + let subject_val = Value::sum(0, [k_i.clone().into()], st).unwrap(); let k1 = ValueKey::try_new(ConstString::new("foo".to_string())).unwrap(); - let v1 = ValueHandle::new(k1.clone(), subject_val.clone()); - let v2 = ValueHandle::new(k1.clone(), Value::extension(k_i).into()); + let PartialValue::PartialSum(ps1) = ValueHandle::new(k1.clone(), subject_val.clone()) + else { + panic!() + }; + let (_tag, fields) = ps1.0.into_iter().exactly_one().unwrap(); + let PartialValue::Value(vh1) = fields.into_iter().exactly_one().unwrap() else { + panic!() + }; + + let PartialValue::Value(v2) = ValueHandle::new(k1.clone(), Value::extension(k_i).into()) + else { + panic!() + }; - let fields = v1.as_sum().unwrap().1.collect::>(); // we do not compare the value, just the key - assert_ne!(fields[0], v2); - assert_eq!(fields[0].value(), v2.value()); + assert_ne!(vh1, v2); + assert_eq!(vh1.1, v2.1); } } From ada7ee1d1afc0fe627179f09ffe515f6889132b9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 12:02:22 +0100 Subject: [PATCH 145/281] Use ascent_run to drop context from all the relations. Lots cleanup to follow --- hugr-passes/src/dataflow/datalog.rs | 366 +++++++++++++++------------- hugr-passes/src/dataflow/machine.rs | 57 ++--- 2 files changed, 219 insertions(+), 204 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index f664a84d4..5cb9afa32 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -8,7 +8,7 @@ )] use ascent::lattice::{BoundedLattice, Lattice}; -use itertools::zip_eq; +use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::hash::Hash; use std::ops::{Index, IndexMut}; @@ -27,183 +27,197 @@ pub enum IO { Output, } -ascent::ascent! { - pub(super) struct AscentProgram>; - relation context(C); - - relation node(C, Node); - relation in_wire(C, Node, IncomingPort); - relation out_wire(C, Node, OutgoingPort); - relation parent_of_node(C, Node, Node); - relation io_node(C, Node, Node, IO); - lattice out_wire_value(C, Node, OutgoingPort, PV); - lattice in_wire_value(C, Node, IncomingPort, PV); - lattice node_in_value_row(C, Node, ValueRow); - - node(c, n) <-- context(c), for n in c.nodes(); - - in_wire(c, n,p) <-- node(c, n), for (p,_) in c.in_value_types(*n); // Note, gets connected inports only - out_wire(c, n,p) <-- node(c, n), for (p,_) in c.out_value_types(*n); // (and likewise) - - parent_of_node(c, parent, child) <-- - node(c, child), if let Some(parent) = c.get_parent(*child); - - io_node(c, parent, child, io) <-- node(c, parent), - if let Some([i,o]) = c.get_io(*parent), - for (child,io) in [(i,IO::Input),(o,IO::Output)]; - - // Initialize all wires to bottom - out_wire_value(c, n,p, PV::bottom()) <-- out_wire(c, n,p); - - in_wire_value(c, n, ip, v) <-- in_wire(c, n, ip), - if let Some((m,op)) = c.single_linked_output(*n, *ip), - out_wire_value(c, m, op, v); - - // We support prepopulating in_wire_value via in_wire_value_proto. - relation in_wire_value_proto(Node, IncomingPort, PV); - in_wire_value(c, n, p, PV::bottom()) <-- in_wire(c, n,p); - in_wire_value(c, n, p, v) <-- node(c,n), - if let Some(sig) = c.signature(*n), - for p in sig.input_ports(), - in_wire_value_proto(n, p, v); - - node_in_value_row(c, n, ValueRow::new(sig.input_count())) <-- node(c, n), if let Some(sig) = c.signature(*n); - node_in_value_row(c, n, ValueRow::single_known(c.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(c, n, p, v); - - out_wire_value(c, n, p, v) <-- - node(c, n), - let op_t = c.get_optype(*n), - if !op_t.is_container(), - if let Some(sig) = op_t.dataflow_signature(), - node_in_value_row(c, n, vs), - if let Some(outs) = propagate_leaf_op(c, *n, &vs[..], sig.output_count()), - for (p,v) in (0..).map(OutgoingPort::from).zip(outs); - - // DFG - relation dfg_node(C, Node); - dfg_node(c,n) <-- node(c, n), if c.get_optype(*n).is_dfg(); - - out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- dfg_node(c,dfg), - io_node(c, dfg, i, IO::Input), in_wire_value(c, dfg, p, v); - - out_wire_value(c, dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(c,dfg), - io_node(c,dfg,o, IO::Output), in_wire_value(c, o, p, v); - - - // TailLoop - - // inputs of tail loop propagate to Input node of child region - out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- node(c, tl), - if c.get_optype(*tl).is_tail_loop(), - io_node(c,tl,i, IO::Input), - in_wire_value(c, tl, p, v); - - // Output node of child region propagate to Input node of child region - out_wire_value(c, in_n, OutgoingPort::from(out_p), v) <-- node(c, tl_n), - if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), - io_node(c,tl_n,in_n, IO::Input), - io_node(c,tl_n,out_n, IO::Output), - node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node - - if let Some(fields) = out_in_row.unpack_first(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 - for (out_p, v) in fields.enumerate(); - - // Output node of child region propagate to outputs of tail loop - out_wire_value(c, tl_n, OutgoingPort::from(out_p), v) <-- node(c, tl_n), - if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), - io_node(c,tl_n,out_n, IO::Output), - node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node - if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 - for (out_p, v) in fields.enumerate(); - - // Conditional - relation conditional_node(C, Node); - relation case_node(C,Node,usize, Node); - - conditional_node (c,n)<-- node(c, n), if c.get_optype(*n).is_conditional(); - case_node(c,cond,i, case) <-- conditional_node(c,cond), - for (i, case) in c.children(*cond).enumerate(), - if c.get_optype(case).is_case(); - - // inputs of conditional propagate into case nodes - out_wire_value(c, i_node, OutgoingPort::from(out_p), v) <-- - case_node(c, cond, case_index, case), - io_node(c, case, i_node, IO::Input), - node_in_value_row(c, cond, in_row), - let conditional = c.get_optype(*cond).as_conditional().unwrap(), - if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), - for (out_p, v) in fields.enumerate(); - - // outputs of case nodes propagate to outputs of conditional *if* case reachable - out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- - case_node(c, cond, _, case), - case_reachable(c, cond, case), - io_node(c, case, o, IO::Output), - in_wire_value(c, o, o_p, v); - - relation case_reachable(C, Node, Node); - case_reachable(c, cond, case) <-- case_node(c,cond,i,case), - in_wire_value(c, cond, IncomingPort::from(0), v), - if v.supports_tag(*i); - - // CFG - relation cfg_node(C, Node); - relation dfb_block(C, Node, Node); - cfg_node(c,n) <-- node(c, n), if c.get_optype(*n).is_cfg(); - dfb_block(c,cfg,blk) <-- cfg_node(c, cfg), for blk in c.children(*cfg), if c.get_optype(blk).is_dataflow_block(); - - // Reachability - relation bb_reachable(C, Node, Node); - bb_reachable(c, cfg, entry) <-- cfg_node(c, cfg), if let Some(entry) = c.children(*cfg).next(); - bb_reachable(c, cfg, bb) <-- cfg_node(c, cfg), - bb_reachable(c, cfg, pred), - io_node(c, pred, pred_out, IO::Output), - in_wire_value(c, pred_out, IncomingPort::from(0), predicate), - for (tag, bb) in c.output_neighbours(*pred).enumerate(), - if predicate.supports_tag(tag); - - // Where do the values "fed" along a control-flow edge come out? - relation _cfg_succ_dest(C, Node, Node, Node); - _cfg_succ_dest(c, cfg, blk, inp) <-- dfb_block(c, cfg, blk), io_node(c, blk, inp, IO::Input); - _cfg_succ_dest(c, cfg, exit, cfg) <-- cfg_node(c, cfg), if let Some(exit) = c.children(*cfg).nth(1); - - // Inputs of CFG propagate to entry block - out_wire_value(c, i_node, OutgoingPort::from(p.index()), v) <-- - cfg_node(c, cfg), - if let Some(entry) = c.children(*cfg).next(), - io_node(c, entry, i_node, IO::Input), - in_wire_value(c, cfg, p, v); - - // Outputs of each reachable block propagated to successor block or (if exit block) then CFG itself - out_wire_value(c, dest, OutgoingPort::from(out_p), v) <-- - dfb_block(c, cfg, pred), - bb_reachable(c, cfg, pred), - let df_block = c.get_optype(*pred).as_dataflow_block().unwrap(), - for (succ_n, succ) in c.output_neighbours(*pred).enumerate(), - io_node(c, pred, out_n, IO::Output), - _cfg_succ_dest(c, cfg, succ, dest), - node_in_value_row(c, out_n, out_in_row), - if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), - for (out_p, v) in fields.enumerate(); - - // Call - relation func_call(C, Node, Node); - func_call(c, call, func_defn) <-- - node(c, call), - if c.get_optype(*call).is_call(), - if let Some(func_defn) = c.static_source(*call); - - out_wire_value(c, inp, OutgoingPort::from(p.index()), v) <-- - func_call(c, call, func), - io_node(c, func, inp, IO::Input), - in_wire_value(c, call, p, v); - - out_wire_value(c, call, OutgoingPort::from(p.index()), v) <-- - func_call(c, call, func), - io_node(c, func, outp, IO::Output), - in_wire_value(c, outp, p, v); +pub(super) struct DatalogResults { + pub in_wire_value: Vec<(Node, IncomingPort, PV)>, + pub out_wire_value: Vec<(Node, OutgoingPort, PV)>, + pub case_reachable: Vec<(Node, Node)>, + pub bb_reachable: Vec<(Node, Node)>, } +pub(super) fn run_datalog>( + in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, + c: &C, +) -> DatalogResults { + let all_results = ascent::ascent_run! { + pub(super) struct AscentProgram; + relation node(Node); + relation in_wire(Node, IncomingPort); + relation out_wire(Node, OutgoingPort); + relation parent_of_node(Node, Node); + relation io_node(Node, Node, IO); + lattice out_wire_value(Node, OutgoingPort, PV); + lattice in_wire_value(Node, IncomingPort, PV); + lattice node_in_value_row(Node, ValueRow); + + node(n) <-- for n in c.nodes(); + + in_wire(n, p) <-- node(n), for (p,_) in c.in_value_types(*n); // Note, gets connected inports only + out_wire(n, p) <-- node(n), for (p,_) in c.out_value_types(*n); // (and likewise) + + parent_of_node(parent, child) <-- + node(child), if let Some(parent) = c.get_parent(*child); + + io_node(parent, child, io) <-- node(parent), + if let Some([i, o]) = c.get_io(*parent), + for (child,io) in [(i,IO::Input),(o,IO::Output)]; + + // Initialize all wires to bottom + out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p); + + in_wire_value(n, ip, v) <-- in_wire(n, ip), + if let Some((m, op)) = c.single_linked_output(*n, *ip), + out_wire_value(m, op, v); + + // We support prepopulating in_wire_value via in_wire_value_proto. + in_wire_value(n, p, PV::bottom()) <-- in_wire(n, p); + in_wire_value(n, p, v) <-- for (n, p, v) in in_wire_value_proto.iter(), + node(n), + if let Some(sig) = c.signature(*n), + if sig.input_ports().contains(p); + + node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = c.signature(*n); + node_in_value_row(n, ValueRow::single_known(c.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); + + out_wire_value(n, p, v) <-- + node(n), + let op_t = c.get_optype(*n), + if !op_t.is_container(), + if let Some(sig) = op_t.dataflow_signature(), + node_in_value_row(n, vs), + if let Some(outs) = propagate_leaf_op(c, *n, &vs[..], sig.output_count()), + for (p, v) in (0..).map(OutgoingPort::from).zip(outs); + + // DFG + relation dfg_node(Node); + dfg_node(n) <-- node(n), if c.get_optype(*n).is_dfg(); + + out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), + io_node(dfg, i, IO::Input), in_wire_value(dfg, p, v); + + out_wire_value(dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), + io_node(dfg, o, IO::Output), in_wire_value(o, p, v); + + + // TailLoop + + // inputs of tail loop propagate to Input node of child region + out_wire_value(i, OutgoingPort::from(p.index()), v) <-- node(tl), + if c.get_optype(*tl).is_tail_loop(), + io_node(tl, i, IO::Input), + in_wire_value(tl, p, v); + + // Output node of child region propagate to Input node of child region + out_wire_value(in_n, OutgoingPort::from(out_p), v) <-- node(tl), + if let Some(tailloop) = c.get_optype(*tl).as_tail_loop(), + io_node(tl, in_n, IO::Input), + io_node(tl, out_n, IO::Output), + node_in_value_row(out_n, out_in_row), // get the whole input row for the output node + + if let Some(fields) = out_in_row.unpack_first(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 + for (out_p, v) in fields.enumerate(); + + // Output node of child region propagate to outputs of tail loop + out_wire_value(tl, OutgoingPort::from(out_p), v) <-- node(tl), + if let Some(tailloop) = c.get_optype(*tl).as_tail_loop(), + io_node(tl, out_n, IO::Output), + node_in_value_row(out_n, out_in_row), // get the whole input row for the output node + if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 + for (out_p, v) in fields.enumerate(); + + // Conditional + relation conditional_node(Node); + relation case_node(Node, usize, Node); + + conditional_node(n)<-- node(n), if c.get_optype(*n).is_conditional(); + case_node(cond, i, case) <-- conditional_node(cond), + for (i, case) in c.children(*cond).enumerate(), + if c.get_optype(case).is_case(); + + // inputs of conditional propagate into case nodes + out_wire_value(i_node, OutgoingPort::from(out_p), v) <-- + case_node(cond, case_index, case), + io_node(case, i_node, IO::Input), + node_in_value_row(cond, in_row), + let conditional = c.get_optype(*cond).as_conditional().unwrap(), + if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), + for (out_p, v) in fields.enumerate(); + + // outputs of case nodes propagate to outputs of conditional *if* case reachable + out_wire_value(cond, OutgoingPort::from(o_p.index()), v) <-- + case_node(cond, _, case), + case_reachable(cond, case), + io_node(case, o, IO::Output), + in_wire_value(o, o_p, v); + + relation case_reachable(Node, Node); + case_reachable(cond, case) <-- case_node(cond, i, case), + in_wire_value(cond, IncomingPort::from(0), v), + if v.supports_tag(*i); + + // CFG + relation cfg_node(Node); + relation dfb_block(Node, Node); + cfg_node(n) <-- node(n), if c.get_optype(*n).is_cfg(); + dfb_block(cfg, blk) <-- cfg_node(cfg), for blk in c.children(*cfg), if c.get_optype(blk).is_dataflow_block(); + + // Reachability + relation bb_reachable(Node, Node); + bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = c.children(*cfg).next(); + bb_reachable(cfg, bb) <-- cfg_node(cfg), + bb_reachable(cfg, pred), + io_node(pred, pred_out, IO::Output), + in_wire_value(pred_out, IncomingPort::from(0), predicate), + for (tag, bb) in c.output_neighbours(*pred).enumerate(), + if predicate.supports_tag(tag); + + // Where do the values "fed" along a control-flow edge come out? + relation _cfg_succ_dest(Node, Node, Node); + _cfg_succ_dest(cfg, blk, inp) <-- dfb_block(cfg, blk), io_node(blk, inp, IO::Input); + _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = c.children(*cfg).nth(1); + + // Inputs of CFG propagate to entry block + out_wire_value(i_node, OutgoingPort::from(p.index()), v) <-- + cfg_node(cfg), + if let Some(entry) = c.children(*cfg).next(), + io_node(entry, i_node, IO::Input), + in_wire_value(cfg, p, v); + + // Outputs of each reachable block propagated to successor block or (if exit block) then CFG itself + out_wire_value(dest, OutgoingPort::from(out_p), v) <-- + dfb_block(cfg, pred), + bb_reachable(cfg, pred), + let df_block = c.get_optype(*pred).as_dataflow_block().unwrap(), + for (succ_n, succ) in c.output_neighbours(*pred).enumerate(), + io_node(pred, out_n, IO::Output), + _cfg_succ_dest(cfg, succ, dest), + node_in_value_row(out_n, out_in_row), + if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), + for (out_p, v) in fields.enumerate(); + + // Call + relation func_call(Node, Node); + func_call(call, func_defn) <-- + node(call), + if c.get_optype(*call).is_call(), + if let Some(func_defn) = c.static_source(*call); + + out_wire_value(inp, OutgoingPort::from(p.index()), v) <-- + func_call(call, func), + io_node(func, inp, IO::Input), + in_wire_value(call, p, v); + + out_wire_value(call, OutgoingPort::from(p.index()), v) <-- + func_call(call, func), + io_node(func, outp, IO::Output), + in_wire_value(outp, p, v); + }; + DatalogResults { + in_wire_value: all_results.in_wire_value, + out_wire_value: all_results.out_wire_value, + case_reachable: all_results.case_reachable, + bb_reachable: all_results.bb_reachable, + } +} fn propagate_leaf_op( c: &impl DFContext, n: Node, diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index ccd13e517..f93384f7d 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -1,35 +1,36 @@ use std::collections::HashMap; use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; -use itertools::Itertools; -use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; +use super::datalog::{run_datalog, DatalogResults}; +use super::{AbstractValue, DFContext, PartialValue}; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] /// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values /// 3. Call [Self::run] to produce [AnalysisResults] which can be inspected via /// [read_out_wire](AnalysisResults::read_out_wire) -pub struct Machine>(AscentProgram); +pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); -/// Results of a dataflow analysis. -pub struct AnalysisResults>( - AscentProgram, // Already run - HashMap>, -); +/// Results of a dataflow analysis, packaged with context for easy inspection +pub struct AnalysisResults> { + context: C, + results: DatalogResults, + out_wire_values: HashMap>, +} /// derived-Default requires the context to be Defaultable, which is unnecessary -impl> Default for Machine { +impl Default for Machine { fn default() -> Self { Self(Default::default()) } } -impl> Machine { +impl Machine { /// Provide initial values for some wires. // Likely for test purposes only - should we make non-pub or #[cfg(test)] ? pub fn prepopulate_wire(&mut self, h: &impl HugrView, wire: Wire, value: PartialValue) { - self.0.in_wire_value_proto.extend( + self.0.extend( h.linked_inputs(wire.node(), wire.source()) .map(|(n, inp)| (n, inp, value.clone())), ); @@ -40,36 +41,36 @@ impl> Machine { /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, /// but should handle other containers.) /// The context passed in allows interpretation of leaf operations. - pub fn run( + pub fn run>( mut self, context: C, in_values: impl IntoIterator)>, ) -> AnalysisResults { let root = context.root(); self.0 - .in_wire_value_proto .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); - self.0.context.push((context,)); - self.0.run(); - let results = self - .0 + let results = run_datalog(self.0, &context); + let out_wire_values = results .out_wire_value .iter() - .map(|(_, n, p, v)| (Wire::new(*n, *p), v.clone())) + .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) .collect(); - AnalysisResults(self.0, results) + AnalysisResults { + context, + results, + out_wire_values, + } } } impl> AnalysisResults { fn context(&self) -> &C { - let (c,) = self.0.context.iter().exactly_one().ok().unwrap(); - c + &self.context } /// Gets the lattice value computed for the given wire pub fn read_out_wire(&self, w: Wire) -> Option> { - self.1.get(&w).cloned() + self.out_wire_values.get(&w).cloned() } /// Tells whether a [TailLoop] node can terminate, i.e. whether @@ -82,10 +83,10 @@ impl> AnalysisResults { hugr.get_optype(node).as_tail_loop()?; let [_, out] = hugr.get_io(node).unwrap(); Some(TailLoopTermination::from_control_value( - self.0 + self.results .in_wire_value .iter() - .find_map(|(_, n, p, v)| (*n == out && p.index() == 0).then_some(v)) + .find_map(|(n, p, v)| (*n == out && p.index() == 0).then_some(v)) .unwrap(), )) } @@ -103,10 +104,10 @@ impl> AnalysisResults { let cond = hugr.get_parent(case)?; hugr.get_optype(cond).as_conditional()?; Some( - self.0 + self.results .case_reachable .iter() - .any(|(_, cond2, case2)| &cond == cond2 && &case == case2), + .any(|(cond2, case2)| &cond == cond2 && &case == case2), ) } @@ -125,10 +126,10 @@ impl> AnalysisResults { return None; }; Some( - self.0 + self.results .bb_reachable .iter() - .any(|(_, cfg2, bb2)| *cfg2 == cfg && *bb2 == bb), + .any(|(cfg2, bb2)| *cfg2 == cfg && *bb2 == bb), ) } } From 7c02d41ae6055cd6082f425f36d21270585cf88c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 15:40:50 +0100 Subject: [PATCH 146/281] DFContext does not Deref, pass Hugr separately --- hugr-passes/src/dataflow.rs | 5 +-- hugr-passes/src/dataflow/datalog.rs | 70 +++++++++++++++-------------- hugr-passes/src/dataflow/machine.rs | 46 +++++++++---------- hugr-passes/src/dataflow/test.rs | 69 +++++----------------------- 4 files changed, 70 insertions(+), 120 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index e04b4c2a8..e24960988 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -12,11 +12,10 @@ pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; use hugr_core::ops::{ExtensionOp, Value}; use hugr_core::{Hugr, Node}; -use std::hash::Hash; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). -pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { +pub trait DFContext { /// Given lattice values for each input, update lattice values for the (dataflow) outputs. /// For extension ops only, excluding [MakeTuple] and [UnpackTuple]. /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] @@ -61,7 +60,7 @@ pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { } fn traverse_value( - s: &impl DFContext, + s: &(impl DFContext + ?Sized), n: Node, fields: &mut Vec, cst: &Value, diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 5cb9afa32..eb00e1c6a 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -34,9 +34,10 @@ pub(super) struct DatalogResults { pub bb_reachable: Vec<(Node, Node)>, } -pub(super) fn run_datalog>( +pub(super) fn run_datalog( in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, - c: &C, + c: &impl DFContext, + hugr: &impl HugrView, ) -> DatalogResults { let all_results = ascent::ascent_run! { pub(super) struct AscentProgram; @@ -49,47 +50,47 @@ pub(super) fn run_datalog>( lattice in_wire_value(Node, IncomingPort, PV); lattice node_in_value_row(Node, ValueRow); - node(n) <-- for n in c.nodes(); + node(n) <-- for n in hugr.nodes(); - in_wire(n, p) <-- node(n), for (p,_) in c.in_value_types(*n); // Note, gets connected inports only - out_wire(n, p) <-- node(n), for (p,_) in c.out_value_types(*n); // (and likewise) + in_wire(n, p) <-- node(n), for (p,_) in hugr.in_value_types(*n); // Note, gets connected inports only + out_wire(n, p) <-- node(n), for (p,_) in hugr.out_value_types(*n); // (and likewise) parent_of_node(parent, child) <-- - node(child), if let Some(parent) = c.get_parent(*child); + node(child), if let Some(parent) = hugr.get_parent(*child); io_node(parent, child, io) <-- node(parent), - if let Some([i, o]) = c.get_io(*parent), + if let Some([i, o]) = hugr.get_io(*parent), for (child,io) in [(i,IO::Input),(o,IO::Output)]; // Initialize all wires to bottom out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p); in_wire_value(n, ip, v) <-- in_wire(n, ip), - if let Some((m, op)) = c.single_linked_output(*n, *ip), + if let Some((m, op)) = hugr.single_linked_output(*n, *ip), out_wire_value(m, op, v); // We support prepopulating in_wire_value via in_wire_value_proto. in_wire_value(n, p, PV::bottom()) <-- in_wire(n, p); in_wire_value(n, p, v) <-- for (n, p, v) in in_wire_value_proto.iter(), node(n), - if let Some(sig) = c.signature(*n), + if let Some(sig) = hugr.signature(*n), if sig.input_ports().contains(p); - node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = c.signature(*n); - node_in_value_row(n, ValueRow::single_known(c.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); + node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = hugr.signature(*n); + node_in_value_row(n, ValueRow::single_known(hugr.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); out_wire_value(n, p, v) <-- node(n), - let op_t = c.get_optype(*n), + let op_t = hugr.get_optype(*n), if !op_t.is_container(), if let Some(sig) = op_t.dataflow_signature(), node_in_value_row(n, vs), - if let Some(outs) = propagate_leaf_op(c, *n, &vs[..], sig.output_count()), + if let Some(outs) = propagate_leaf_op(c, hugr, *n, &vs[..], sig.output_count()), for (p, v) in (0..).map(OutgoingPort::from).zip(outs); // DFG relation dfg_node(Node); - dfg_node(n) <-- node(n), if c.get_optype(*n).is_dfg(); + dfg_node(n) <-- node(n), if hugr.get_optype(*n).is_dfg(); out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), io_node(dfg, i, IO::Input), in_wire_value(dfg, p, v); @@ -102,13 +103,13 @@ pub(super) fn run_datalog>( // inputs of tail loop propagate to Input node of child region out_wire_value(i, OutgoingPort::from(p.index()), v) <-- node(tl), - if c.get_optype(*tl).is_tail_loop(), + if hugr.get_optype(*tl).is_tail_loop(), io_node(tl, i, IO::Input), in_wire_value(tl, p, v); // Output node of child region propagate to Input node of child region out_wire_value(in_n, OutgoingPort::from(out_p), v) <-- node(tl), - if let Some(tailloop) = c.get_optype(*tl).as_tail_loop(), + if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), io_node(tl, in_n, IO::Input), io_node(tl, out_n, IO::Output), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node @@ -118,7 +119,7 @@ pub(super) fn run_datalog>( // Output node of child region propagate to outputs of tail loop out_wire_value(tl, OutgoingPort::from(out_p), v) <-- node(tl), - if let Some(tailloop) = c.get_optype(*tl).as_tail_loop(), + if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), io_node(tl, out_n, IO::Output), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 @@ -128,17 +129,17 @@ pub(super) fn run_datalog>( relation conditional_node(Node); relation case_node(Node, usize, Node); - conditional_node(n)<-- node(n), if c.get_optype(*n).is_conditional(); + conditional_node(n)<-- node(n), if hugr.get_optype(*n).is_conditional(); case_node(cond, i, case) <-- conditional_node(cond), - for (i, case) in c.children(*cond).enumerate(), - if c.get_optype(case).is_case(); + for (i, case) in hugr.children(*cond).enumerate(), + if hugr.get_optype(case).is_case(); // inputs of conditional propagate into case nodes out_wire_value(i_node, OutgoingPort::from(out_p), v) <-- case_node(cond, case_index, case), io_node(case, i_node, IO::Input), node_in_value_row(cond, in_row), - let conditional = c.get_optype(*cond).as_conditional().unwrap(), + let conditional = hugr.get_optype(*cond).as_conditional().unwrap(), if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), for (out_p, v) in fields.enumerate(); @@ -157,28 +158,28 @@ pub(super) fn run_datalog>( // CFG relation cfg_node(Node); relation dfb_block(Node, Node); - cfg_node(n) <-- node(n), if c.get_optype(*n).is_cfg(); - dfb_block(cfg, blk) <-- cfg_node(cfg), for blk in c.children(*cfg), if c.get_optype(blk).is_dataflow_block(); + cfg_node(n) <-- node(n), if hugr.get_optype(*n).is_cfg(); + dfb_block(cfg, blk) <-- cfg_node(cfg), for blk in hugr.children(*cfg), if hugr.get_optype(blk).is_dataflow_block(); // Reachability relation bb_reachable(Node, Node); - bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = c.children(*cfg).next(); + bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = hugr.children(*cfg).next(); bb_reachable(cfg, bb) <-- cfg_node(cfg), bb_reachable(cfg, pred), io_node(pred, pred_out, IO::Output), in_wire_value(pred_out, IncomingPort::from(0), predicate), - for (tag, bb) in c.output_neighbours(*pred).enumerate(), + for (tag, bb) in hugr.output_neighbours(*pred).enumerate(), if predicate.supports_tag(tag); // Where do the values "fed" along a control-flow edge come out? relation _cfg_succ_dest(Node, Node, Node); _cfg_succ_dest(cfg, blk, inp) <-- dfb_block(cfg, blk), io_node(blk, inp, IO::Input); - _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = c.children(*cfg).nth(1); + _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = hugr.children(*cfg).nth(1); // Inputs of CFG propagate to entry block out_wire_value(i_node, OutgoingPort::from(p.index()), v) <-- cfg_node(cfg), - if let Some(entry) = c.children(*cfg).next(), + if let Some(entry) = hugr.children(*cfg).next(), io_node(entry, i_node, IO::Input), in_wire_value(cfg, p, v); @@ -186,8 +187,8 @@ pub(super) fn run_datalog>( out_wire_value(dest, OutgoingPort::from(out_p), v) <-- dfb_block(cfg, pred), bb_reachable(cfg, pred), - let df_block = c.get_optype(*pred).as_dataflow_block().unwrap(), - for (succ_n, succ) in c.output_neighbours(*pred).enumerate(), + let df_block = hugr.get_optype(*pred).as_dataflow_block().unwrap(), + for (succ_n, succ) in hugr.output_neighbours(*pred).enumerate(), io_node(pred, out_n, IO::Output), _cfg_succ_dest(cfg, succ, dest), node_in_value_row(out_n, out_in_row), @@ -198,8 +199,8 @@ pub(super) fn run_datalog>( relation func_call(Node, Node); func_call(call, func_defn) <-- node(call), - if c.get_optype(*call).is_call(), - if let Some(func_defn) = c.static_source(*call); + if hugr.get_optype(*call).is_call(), + if let Some(func_defn) = hugr.static_source(*call); out_wire_value(inp, OutgoingPort::from(p.index()), v) <-- func_call(call, func), @@ -220,11 +221,12 @@ pub(super) fn run_datalog>( } fn propagate_leaf_op( c: &impl DFContext, + hugr: &impl HugrView, n: Node, ins: &[PV], num_outs: usize, ) -> Option> { - match c.get_optype(n) { + match hugr.get_optype(n) { // Handle basics here. I guess (given the current interface) we could allow // DFContext to handle these but at the least we'd want these impls to be // easily available for reuse. @@ -247,11 +249,11 @@ fn propagate_leaf_op( OpType::Const(_) => None, // handled by LoadConstant: OpType::LoadConstant(load_op) => { assert!(ins.is_empty()); // static edge, so need to find constant - let const_node = c + let const_node = hugr .single_linked_output(n, load_op.constant_port()) .unwrap() .0; - let const_val = c.get_optype(const_node).as_const().unwrap().value(); + let const_val = hugr.get_optype(const_node).as_const().unwrap().value(); Some(ValueRow::single_known( 1, 0, diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index f93384f7d..f7f97463d 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -13,8 +13,8 @@ use super::{AbstractValue, DFContext, PartialValue}; pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); /// Results of a dataflow analysis, packaged with context for easy inspection -pub struct AnalysisResults> { - context: C, +pub struct AnalysisResults { + hugr: H, results: DatalogResults, out_wire_values: HashMap>, } @@ -41,33 +41,30 @@ impl Machine { /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, /// but should handle other containers.) /// The context passed in allows interpretation of leaf operations. - pub fn run>( + pub fn run( mut self, - context: C, + context: &impl DFContext, + hugr: H, in_values: impl IntoIterator)>, - ) -> AnalysisResults { - let root = context.root(); + ) -> AnalysisResults { + let root = hugr.root(); self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); - let results = run_datalog(self.0, &context); + let results = run_datalog(self.0, context, &hugr); let out_wire_values = results .out_wire_value .iter() .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) .collect(); AnalysisResults { - context, + hugr, results, out_wire_values, } } } -impl> AnalysisResults { - fn context(&self) -> &C { - &self.context - } - +impl AnalysisResults { /// Gets the lattice value computed for the given wire pub fn read_out_wire(&self, w: Wire) -> Option> { self.out_wire_values.get(&w).cloned() @@ -79,9 +76,8 @@ impl> AnalysisResults { /// /// [TailLoop]: hugr_core::ops::TailLoop pub fn tail_loop_terminates(&self, node: Node) -> Option { - let hugr = self.context(); - hugr.get_optype(node).as_tail_loop()?; - let [_, out] = hugr.get_io(node).unwrap(); + self.hugr.get_optype(node).as_tail_loop()?; + let [_, out] = self.hugr.get_io(node).unwrap(); Some(TailLoopTermination::from_control_value( self.results .in_wire_value @@ -99,10 +95,9 @@ impl> AnalysisResults { /// [Case]: hugr_core::ops::Case /// [Conditional]: hugr_core::ops::Conditional pub fn case_reachable(&self, case: Node) -> Option { - let hugr = self.context(); - hugr.get_optype(case).as_case()?; - let cond = hugr.get_parent(case)?; - hugr.get_optype(cond).as_conditional()?; + self.hugr.get_optype(case).as_case()?; + let cond = self.hugr.get_parent(case)?; + self.hugr.get_optype(cond).as_conditional()?; Some( self.results .case_reachable @@ -118,10 +113,9 @@ impl> AnalysisResults { /// [DataflowBlock]: hugr_core::ops::DataflowBlock /// [ExitBlock]: hugr_core::ops::ExitBlock pub fn bb_reachable(&self, bb: Node) -> Option { - let hugr = self.context(); - let cfg = hugr.get_parent(bb)?; // Not really required...?? - hugr.get_optype(cfg).as_cfg()?; - let t = hugr.get_optype(bb); + let cfg = self.hugr.get_parent(bb)?; // Not really required...?? + self.hugr.get_optype(cfg).as_cfg()?; + let t = self.hugr.get_optype(bb); if !t.is_dataflow_block() && !t.is_exit_block() { return None; }; @@ -134,7 +128,7 @@ impl> AnalysisResults { } } -impl> AnalysisResults +impl AnalysisResults where Value: From, { @@ -152,7 +146,7 @@ where pub fn try_read_wire_value(&self, w: Wire) -> Result> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self - .context() + .hugr .out_value_types(w.node()) .find(|(p, _)| *p == w.source()) .unwrap(); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 74d504b74..fe4af1c46 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,6 +1,3 @@ -use std::hash::{Hash, Hasher}; -use std::sync::Arc; - use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::builder::{CFGBuilder, Container, DataflowHugr}; @@ -26,49 +23,9 @@ enum Void {} impl AbstractValue for Void {} -struct TestContext(Arc); - -// Deriving Clone requires H:HugrView to implement Clone, -// but we don't need that as we only clone the Arc. -impl Clone for TestContext { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} - -impl std::ops::Deref for TestContext { - type Target = hugr_core::Hugr; - - fn deref(&self) -> &Self::Target { - self.0.base_hugr() - } -} - -// Any value used in an Ascent program must be hashable. -// However, there should only be one DFContext, so its hash is immaterial. -impl Hash for TestContext { - fn hash(&self, _state: &mut I) {} -} - -impl PartialEq for TestContext { - fn eq(&self, other: &Self) -> bool { - // Any AscentProgram should have only one DFContext (maybe cloned) - assert!(Arc::ptr_eq(&self.0, &other.0)); - true - } -} - -impl Eq for TestContext {} - -impl PartialOrd for TestContext { - fn partial_cmp(&self, other: &Self) -> Option { - // Any AscentProgram should have only one DFContext (maybe cloned) - assert!(Arc::ptr_eq(&self.0, &other.0)); - Some(std::cmp::Ordering::Equal) - } -} +struct TestContext; -impl DFContext for TestContext {} +impl DFContext for TestContext {} // This allows testing creation of tuple/sum Values (only) impl From for Value { @@ -97,7 +54,7 @@ fn test_make_tuple() { let v3 = builder.make_tuple([v1, v2]).unwrap(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(&TestContext, &hugr, []); let x = results.try_read_wire_value(v3).unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); @@ -113,7 +70,7 @@ fn test_unpack_tuple_const() { .outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(&TestContext, &hugr, []); let o1_r = results.try_read_wire_value(o1).unwrap(); assert_eq!(o1_r, Value::false_val()); @@ -136,7 +93,7 @@ fn test_tail_loop_never_iterates() { let [tl_o] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(&TestContext, &hugr, []); let o_r = results.try_read_wire_value(tl_o).unwrap(); assert_eq!(o_r, r_v); @@ -165,7 +122,7 @@ fn test_tail_loop_always_iterates() { let [tl_o1, tl_o2] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(&TestContext, &hugr, []); let o_r1 = results.read_out_wire(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom()); @@ -203,7 +160,7 @@ fn test_tail_loop_two_iters() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(&TestContext, &hugr, []); let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true_or_false()); @@ -269,7 +226,7 @@ fn test_tail_loop_containing_conditional() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(&TestContext, &hugr, []); let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true()); @@ -321,7 +278,7 @@ fn test_conditional() { 2, [PartialValue::new_variant(0, [])], )); - let results = Machine::default().run(TestContext(Arc::new(&hugr)), [(0.into(), arg_pv)]); + let results = Machine::default().run(&TestContext, &hugr, [(0.into(), arg_pv)]); let cond_r1 = results.try_read_wire_value(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); @@ -445,7 +402,8 @@ fn test_cfg( ) { let root = xor_and_cfg.root(); let results = Machine::default().run( - TestContext(Arc::new(xor_and_cfg)), + &TestContext, + &xor_and_cfg, [(0.into(), inp0), (1.into(), inp1)], ); @@ -481,10 +439,7 @@ fn test_call( .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) .unwrap(); - let results = Machine::default().run( - TestContext(Arc::new(&hugr)), - [(0.into(), inp0), (1.into(), inp1)], - ); + let results = Machine::default().run(&TestContext, &hugr, [(0.into(), inp0), (1.into(), inp1)]); let [res0, res1] = [0, 1].map(|i| results.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); // The two calls alias so both results will be the same: From 3d4f01675013efa7361b325769b501329002194d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 15:44:59 +0100 Subject: [PATCH 147/281] Massively reduce scope of clippy-allow to inside run_datalog --- hugr-passes/src/dataflow/datalog.rs | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index eb00e1c6a..b1c341dc4 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,11 +1,4 @@ //! [ascent] datalog implementation of analysis. -//! Since ascent-(macro-)generated code generates a bunch of warnings, -//! keep code in here to a minimum. -#![allow( - clippy::clone_on_copy, - clippy::unused_enumerate_index, - clippy::collapsible_if -)] use ascent::lattice::{BoundedLattice, Lattice}; use itertools::{zip_eq, Itertools}; @@ -39,6 +32,13 @@ pub(super) fn run_datalog( c: &impl DFContext, hugr: &impl HugrView, ) -> DatalogResults { + // ascent-(macro-)generated code generates a bunch of warnings, + // keep code in here to a minimum. + #![allow( + clippy::clone_on_copy, + clippy::unused_enumerate_index, + clippy::collapsible_if + )] let all_results = ascent::ascent_run! { pub(super) struct AscentProgram; relation node(Node); @@ -145,7 +145,7 @@ pub(super) fn run_datalog( // outputs of case nodes propagate to outputs of conditional *if* case reachable out_wire_value(cond, OutgoingPort::from(o_p.index()), v) <-- - case_node(cond, _, case), + case_node(cond, _i, case), case_reachable(cond, case), io_node(case, o, IO::Output), in_wire_value(o, o_p, v); @@ -219,6 +219,7 @@ pub(super) fn run_datalog( bb_reachable: all_results.bb_reachable, } } + fn propagate_leaf_op( c: &impl DFContext, hugr: &impl HugrView, From 8bab4d5e099af0159caa0b5f5b1ed46dfa42d95f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 15:51:24 +0100 Subject: [PATCH 148/281] Move ValueRow into own file --- hugr-passes/src/dataflow.rs | 4 +- hugr-passes/src/dataflow/datalog.rs | 95 +----------------------- hugr-passes/src/dataflow/value_row.rs | 101 ++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 94 deletions(-) create mode 100644 hugr-passes/src/dataflow/value_row.rs diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index e24960988..f3e763feb 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -2,15 +2,15 @@ //! Dataflow analysis of Hugrs. mod datalog; +mod value_row; mod machine; -use hugr_core::ops::constant::CustomConst; pub use machine::{AnalysisResults, Machine, TailLoopTermination}; mod partial_value; pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; -use hugr_core::ops::{ExtensionOp, Value}; +use hugr_core::ops::{constant::CustomConst, ExtensionOp, Value}; use hugr_core::{Hugr, Node}; /// Clients of the dataflow framework (particular analyses, such as constant folding) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index b1c341dc4..143638ae2 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,15 +1,14 @@ //! [ascent] datalog implementation of analysis. -use ascent::lattice::{BoundedLattice, Lattice}; -use itertools::{zip_eq, Itertools}; -use std::cmp::Ordering; +use ascent::lattice::BoundedLattice; +use itertools::Itertools; use std::hash::Hash; -use std::ops::{Index, IndexMut}; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use hugr_core::ops::{OpTrait, OpType}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; +use super::value_row::ValueRow; use super::{AbstractValue, DFContext, PartialValue}; type PV = PartialValue; @@ -273,91 +272,3 @@ fn propagate_leaf_op( o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" } } - -// Wrap a (known-length) row of values into a lattice. -#[derive(PartialEq, Clone, Eq, Hash)] -struct ValueRow(Vec>); - -impl ValueRow { - fn new(len: usize) -> Self { - Self(vec![PartialValue::bottom(); len]) - } - - fn single_known(len: usize, idx: usize, v: PartialValue) -> Self { - assert!(idx < len); - let mut r = Self::new(len); - r.0[idx] = v; - r - } - - pub fn unpack_first( - &self, - variant: usize, - len: usize, - ) -> Option>> { - let vals = self[0].variant_values(variant, len)?; - Some(vals.into_iter().chain(self.0[1..].to_owned())) - } -} - -impl FromIterator> for ValueRow { - fn from_iter>>(iter: T) -> Self { - Self(iter.into_iter().collect()) - } -} - -impl PartialOrd for ValueRow { - fn partial_cmp(&self, other: &Self) -> Option { - self.0.partial_cmp(&other.0) - } -} - -impl Lattice for ValueRow { - fn join_mut(&mut self, other: Self) -> bool { - assert_eq!(self.0.len(), other.0.len()); - let mut changed = false; - for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { - changed |= v1.join_mut(v2); - } - changed - } - - fn meet_mut(&mut self, other: Self) -> bool { - assert_eq!(self.0.len(), other.0.len()); - let mut changed = false; - for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { - changed |= v1.meet_mut(v2); - } - changed - } -} - -impl IntoIterator for ValueRow { - type Item = PartialValue; - - type IntoIter = > as IntoIterator>::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -impl Index for ValueRow -where - Vec>: Index, -{ - type Output = > as Index>::Output; - - fn index(&self, index: Idx) -> &Self::Output { - self.0.index(index) - } -} - -impl IndexMut for ValueRow -where - Vec>: IndexMut, -{ - fn index_mut(&mut self, index: Idx) -> &mut Self::Output { - self.0.index_mut(index) - } -} diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs new file mode 100644 index 000000000..ebdbf1b75 --- /dev/null +++ b/hugr-passes/src/dataflow/value_row.rs @@ -0,0 +1,101 @@ +// Wrap a (known-length) row of values into a lattice. + +use std::{ + cmp::Ordering, + ops::{Index, IndexMut}, +}; + +use ascent::{lattice::BoundedLattice, Lattice}; +use itertools::zip_eq; + +use super::{AbstractValue, PartialValue}; + +#[derive(PartialEq, Clone, Eq, Hash)] +pub(super) struct ValueRow(Vec>); + +impl ValueRow { + pub fn new(len: usize) -> Self { + Self(vec![PartialValue::bottom(); len]) + } + + pub fn single_known(len: usize, idx: usize, v: PartialValue) -> Self { + assert!(idx < len); + let mut r = Self::new(len); + r.0[idx] = v; + r + } + + /// The first value in this ValueRow must be a sum; + /// returns a new ValueRow given by unpacking the elements of the specified variant of said first value, + /// then appending the rest of the values in this row. + pub fn unpack_first( + &self, + variant: usize, + len: usize, + ) -> Option>> { + let vals = self[0].variant_values(variant, len)?; + Some(vals.into_iter().chain(self.0[1..].to_owned())) + } +} + +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl PartialOrd for ValueRow { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} + +impl Lattice for ValueRow { + fn join_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.join_mut(v2); + } + changed + } + + fn meet_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.meet_mut(v2); + } + changed + } +} + +impl IntoIterator for ValueRow { + type Item = PartialValue; + + type IntoIter = > as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl Index for ValueRow +where + Vec>: Index, +{ + type Output = > as Index>::Output; + + fn index(&self, index: Idx) -> &Self::Output { + self.0.index(index) + } +} + +impl IndexMut for ValueRow +where + Vec>: IndexMut, +{ + fn index_mut(&mut self, index: Idx) -> &mut Self::Output { + self.0.index_mut(index) + } +} From 3eccadfbd733f10a7c1db0edb469964b10bbf0c4 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 15:59:17 +0100 Subject: [PATCH 149/281] Move Machine into datalog.rs, pub(super) fields in AnalysisResults, rm DatalogResults --- hugr-passes/src/dataflow.rs | 3 +- hugr-passes/src/dataflow/datalog.rs | 67 +++++++++++++---- .../src/dataflow/{machine.rs => results.rs} | 72 +++---------------- 3 files changed, 65 insertions(+), 77 deletions(-) rename hugr-passes/src/dataflow/{machine.rs => results.rs} (65%) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index f3e763feb..cce6ea97d 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -2,10 +2,11 @@ //! Dataflow analysis of Hugrs. mod datalog; +pub use datalog::Machine; mod value_row; mod machine; -pub use machine::{AnalysisResults, Machine, TailLoopTermination}; +pub use machine::{AnalysisResults, TailLoopTermination}; mod partial_value; pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 143638ae2..0421567d0 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -6,10 +6,10 @@ use std::hash::Hash; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use hugr_core::ops::{OpTrait, OpType}; -use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; +use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; -use super::{AbstractValue, DFContext, PartialValue}; +use super::{AbstractValue, AnalysisResults, DFContext, PartialValue}; type PV = PartialValue; @@ -19,18 +19,53 @@ pub enum IO { Output, } -pub(super) struct DatalogResults { - pub in_wire_value: Vec<(Node, IncomingPort, PV)>, - pub out_wire_value: Vec<(Node, OutgoingPort, PV)>, - pub case_reachable: Vec<(Node, Node)>, - pub bb_reachable: Vec<(Node, Node)>, +/// Basic structure for performing an analysis. Usage: +/// 1. Get a new instance via [Self::default()] +/// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values +/// 3. Call [Self::run] to produce [AnalysisResults] which can be inspected via +/// [read_out_wire](AnalysisResults::read_out_wire) +pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); + +/// derived-Default requires the context to be Defaultable, which is unnecessary +impl Default for Machine { + fn default() -> Self { + Self(Default::default()) + } } -pub(super) fn run_datalog( +impl Machine { + /// Provide initial values for some wires. + // Likely for test purposes only - should we make non-pub or #[cfg(test)] ? + pub fn prepopulate_wire(&mut self, h: &impl HugrView, wire: Wire, value: PartialValue) { + self.0.extend( + h.linked_inputs(wire.node(), wire.source()) + .map(|(n, inp)| (n, inp, value.clone())), + ); + } + + /// Run the analysis (iterate until a lattice fixpoint is reached), + /// given initial values for some of the root node inputs. + /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, + /// but should handle other containers.) + /// The context passed in allows interpretation of leaf operations. + pub fn run( + mut self, + context: &impl DFContext, + hugr: H, + in_values: impl IntoIterator)>, + ) -> AnalysisResults { + let root = hugr.root(); + self.0 + .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); + run_datalog(self.0, context, hugr) + } +} + +pub(super) fn run_datalog( in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, c: &impl DFContext, - hugr: &impl HugrView, -) -> DatalogResults { + hugr: H, +) -> AnalysisResults { // ascent-(macro-)generated code generates a bunch of warnings, // keep code in here to a minimum. #![allow( @@ -84,7 +119,7 @@ pub(super) fn run_datalog( if !op_t.is_container(), if let Some(sig) = op_t.dataflow_signature(), node_in_value_row(n, vs), - if let Some(outs) = propagate_leaf_op(c, hugr, *n, &vs[..], sig.output_count()), + if let Some(outs) = propagate_leaf_op(c, &hugr, *n, &vs[..], sig.output_count()), for (p, v) in (0..).map(OutgoingPort::from).zip(outs); // DFG @@ -211,9 +246,15 @@ pub(super) fn run_datalog( io_node(func, outp, IO::Output), in_wire_value(outp, p, v); }; - DatalogResults { + let out_wire_values = all_results + .out_wire_value + .iter() + .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) + .collect(); + AnalysisResults { + hugr, + out_wire_values, in_wire_value: all_results.in_wire_value, - out_wire_value: all_results.out_wire_value, case_reachable: all_results.case_reachable, bb_reachable: all_results.bb_reachable, } diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/results.rs similarity index 65% rename from hugr-passes/src/dataflow/machine.rs rename to hugr-passes/src/dataflow/results.rs index f7f97463d..0be8072b0 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -2,66 +2,15 @@ use std::collections::HashMap; use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; -use super::datalog::{run_datalog, DatalogResults}; -use super::{AbstractValue, DFContext, PartialValue}; - -/// Basic structure for performing an analysis. Usage: -/// 1. Get a new instance via [Self::default()] -/// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values -/// 3. Call [Self::run] to produce [AnalysisResults] which can be inspected via -/// [read_out_wire](AnalysisResults::read_out_wire) -pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); +use super::{AbstractValue, PartialValue}; /// Results of a dataflow analysis, packaged with context for easy inspection pub struct AnalysisResults { - hugr: H, - results: DatalogResults, - out_wire_values: HashMap>, -} - -/// derived-Default requires the context to be Defaultable, which is unnecessary -impl Default for Machine { - fn default() -> Self { - Self(Default::default()) - } -} - -impl Machine { - /// Provide initial values for some wires. - // Likely for test purposes only - should we make non-pub or #[cfg(test)] ? - pub fn prepopulate_wire(&mut self, h: &impl HugrView, wire: Wire, value: PartialValue) { - self.0.extend( - h.linked_inputs(wire.node(), wire.source()) - .map(|(n, inp)| (n, inp, value.clone())), - ); - } - - /// Run the analysis (iterate until a lattice fixpoint is reached), - /// given initial values for some of the root node inputs. - /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, - /// but should handle other containers.) - /// The context passed in allows interpretation of leaf operations. - pub fn run( - mut self, - context: &impl DFContext, - hugr: H, - in_values: impl IntoIterator)>, - ) -> AnalysisResults { - let root = hugr.root(); - self.0 - .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); - let results = run_datalog(self.0, context, &hugr); - let out_wire_values = results - .out_wire_value - .iter() - .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) - .collect(); - AnalysisResults { - hugr, - results, - out_wire_values, - } - } + pub(super) hugr: H, + pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, + pub(super) case_reachable: Vec<(Node, Node)>, + pub(super) bb_reachable: Vec<(Node, Node)>, + pub(super) out_wire_values: HashMap>, } impl AnalysisResults { @@ -79,8 +28,7 @@ impl AnalysisResults { self.hugr.get_optype(node).as_tail_loop()?; let [_, out] = self.hugr.get_io(node).unwrap(); Some(TailLoopTermination::from_control_value( - self.results - .in_wire_value + self.in_wire_value .iter() .find_map(|(n, p, v)| (*n == out && p.index() == 0).then_some(v)) .unwrap(), @@ -99,8 +47,7 @@ impl AnalysisResults { let cond = self.hugr.get_parent(case)?; self.hugr.get_optype(cond).as_conditional()?; Some( - self.results - .case_reachable + self.case_reachable .iter() .any(|(cond2, case2)| &cond == cond2 && &case == case2), ) @@ -120,8 +67,7 @@ impl AnalysisResults { return None; }; Some( - self.results - .bb_reachable + self.bb_reachable .iter() .any(|(cfg2, bb2)| *cfg2 == cfg && *bb2 == bb), ) From 0f4fa522643777fc65cfdfcd3bc0a253e7204272 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 16:02:03 +0100 Subject: [PATCH 150/281] Move machine.rs to results.rs --- hugr-passes/src/dataflow.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index cce6ea97d..5cd0fa7de 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -5,8 +5,8 @@ mod datalog; pub use datalog::Machine; mod value_row; -mod machine; -pub use machine::{AnalysisResults, TailLoopTermination}; +mod results; +pub use results::{AnalysisResults, TailLoopTermination}; mod partial_value; pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; From caa8acaa89bfece8bbf00751a03a3ae9f5140e29 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 16:08:41 +0100 Subject: [PATCH 151/281] Remove enum IO, replace io_node -> input_child/output_child --- hugr-passes/src/dataflow/datalog.rs | 44 ++++++++++++----------------- 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 0421567d0..b8b7f18bf 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -2,7 +2,6 @@ use ascent::lattice::BoundedLattice; use itertools::Itertools; -use std::hash::Hash; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use hugr_core::ops::{OpTrait, OpType}; @@ -13,12 +12,6 @@ use super::{AbstractValue, AnalysisResults, DFContext, PartialValue}; type PV = PartialValue; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum IO { - Input, - Output, -} - /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] /// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values @@ -79,7 +72,8 @@ pub(super) fn run_datalog( relation in_wire(Node, IncomingPort); relation out_wire(Node, OutgoingPort); relation parent_of_node(Node, Node); - relation io_node(Node, Node, IO); + relation input_child(Node, Node); + relation output_child(Node, Node); lattice out_wire_value(Node, OutgoingPort, PV); lattice in_wire_value(Node, IncomingPort, PV); lattice node_in_value_row(Node, ValueRow); @@ -92,9 +86,8 @@ pub(super) fn run_datalog( parent_of_node(parent, child) <-- node(child), if let Some(parent) = hugr.get_parent(*child); - io_node(parent, child, io) <-- node(parent), - if let Some([i, o]) = hugr.get_io(*parent), - for (child,io) in [(i,IO::Input),(o,IO::Output)]; + input_child(parent, input) <-- node(parent), if let Some([input, _output]) = hugr.get_io(*parent); + output_child(parent, output) <-- node(parent), if let Some([_input, output]) = hugr.get_io(*parent); // Initialize all wires to bottom out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p); @@ -127,10 +120,10 @@ pub(super) fn run_datalog( dfg_node(n) <-- node(n), if hugr.get_optype(*n).is_dfg(); out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), - io_node(dfg, i, IO::Input), in_wire_value(dfg, p, v); + input_child(dfg, i), in_wire_value(dfg, p, v); out_wire_value(dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), - io_node(dfg, o, IO::Output), in_wire_value(o, p, v); + output_child(dfg, o), in_wire_value(o, p, v); // TailLoop @@ -138,23 +131,22 @@ pub(super) fn run_datalog( // inputs of tail loop propagate to Input node of child region out_wire_value(i, OutgoingPort::from(p.index()), v) <-- node(tl), if hugr.get_optype(*tl).is_tail_loop(), - io_node(tl, i, IO::Input), + input_child(tl, i), in_wire_value(tl, p, v); // Output node of child region propagate to Input node of child region out_wire_value(in_n, OutgoingPort::from(out_p), v) <-- node(tl), if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), - io_node(tl, in_n, IO::Input), - io_node(tl, out_n, IO::Output), + input_child(tl, in_n), + output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node - if let Some(fields) = out_in_row.unpack_first(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 for (out_p, v) in fields.enumerate(); // Output node of child region propagate to outputs of tail loop out_wire_value(tl, OutgoingPort::from(out_p), v) <-- node(tl), if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), - io_node(tl, out_n, IO::Output), + output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 for (out_p, v) in fields.enumerate(); @@ -171,7 +163,7 @@ pub(super) fn run_datalog( // inputs of conditional propagate into case nodes out_wire_value(i_node, OutgoingPort::from(out_p), v) <-- case_node(cond, case_index, case), - io_node(case, i_node, IO::Input), + input_child(case, i_node), node_in_value_row(cond, in_row), let conditional = hugr.get_optype(*cond).as_conditional().unwrap(), if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), @@ -181,7 +173,7 @@ pub(super) fn run_datalog( out_wire_value(cond, OutgoingPort::from(o_p.index()), v) <-- case_node(cond, _i, case), case_reachable(cond, case), - io_node(case, o, IO::Output), + output_child(case, o), in_wire_value(o, o_p, v); relation case_reachable(Node, Node); @@ -200,21 +192,21 @@ pub(super) fn run_datalog( bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = hugr.children(*cfg).next(); bb_reachable(cfg, bb) <-- cfg_node(cfg), bb_reachable(cfg, pred), - io_node(pred, pred_out, IO::Output), + output_child(pred, pred_out), in_wire_value(pred_out, IncomingPort::from(0), predicate), for (tag, bb) in hugr.output_neighbours(*pred).enumerate(), if predicate.supports_tag(tag); // Where do the values "fed" along a control-flow edge come out? relation _cfg_succ_dest(Node, Node, Node); - _cfg_succ_dest(cfg, blk, inp) <-- dfb_block(cfg, blk), io_node(blk, inp, IO::Input); + _cfg_succ_dest(cfg, blk, inp) <-- dfb_block(cfg, blk), input_child(blk, inp); _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = hugr.children(*cfg).nth(1); // Inputs of CFG propagate to entry block out_wire_value(i_node, OutgoingPort::from(p.index()), v) <-- cfg_node(cfg), if let Some(entry) = hugr.children(*cfg).next(), - io_node(entry, i_node, IO::Input), + input_child(entry, i_node), in_wire_value(cfg, p, v); // Outputs of each reachable block propagated to successor block or (if exit block) then CFG itself @@ -223,7 +215,7 @@ pub(super) fn run_datalog( bb_reachable(cfg, pred), let df_block = hugr.get_optype(*pred).as_dataflow_block().unwrap(), for (succ_n, succ) in hugr.output_neighbours(*pred).enumerate(), - io_node(pred, out_n, IO::Output), + output_child(pred, out_n), _cfg_succ_dest(cfg, succ, dest), node_in_value_row(out_n, out_in_row), if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), @@ -238,12 +230,12 @@ pub(super) fn run_datalog( out_wire_value(inp, OutgoingPort::from(p.index()), v) <-- func_call(call, func), - io_node(func, inp, IO::Input), + input_child(func, inp), in_wire_value(call, p, v); out_wire_value(call, OutgoingPort::from(p.index()), v) <-- func_call(call, func), - io_node(func, outp, IO::Output), + output_child(func, outp), in_wire_value(outp, p, v); }; let out_wire_values = all_results From e7f61fc9016ff12bd0295d2923850a750bbe9ef7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 16:10:27 +0100 Subject: [PATCH 152/281] move docs --- hugr-passes/src/dataflow/datalog.rs | 3 +-- hugr-passes/src/dataflow/results.rs | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index b8b7f18bf..0d4b6be25 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -15,8 +15,7 @@ type PV = PartialValue; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] /// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values -/// 3. Call [Self::run] to produce [AnalysisResults] which can be inspected via -/// [read_out_wire](AnalysisResults::read_out_wire) +/// 3. Call [Self::run] to produce [AnalysisResults] pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); /// derived-Default requires the context to be Defaultable, which is unnecessary diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index 0be8072b0..f457ef68c 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -4,7 +4,8 @@ use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, use super::{AbstractValue, PartialValue}; -/// Results of a dataflow analysis, packaged with context for easy inspection +/// Results of a dataflow analysis, packaged with the Hugr for easy inspection. +/// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). pub struct AnalysisResults { pub(super) hugr: H, pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, From 811802ce94149aace3afa014f9061bcfe7faa40f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 16:33:39 +0100 Subject: [PATCH 153/281] Remove/inline/dedup dfb_block --- hugr-passes/src/dataflow/datalog.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 0d4b6be25..f3f94da3f 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -182,9 +182,7 @@ pub(super) fn run_datalog( // CFG relation cfg_node(Node); - relation dfb_block(Node, Node); cfg_node(n) <-- node(n), if hugr.get_optype(*n).is_cfg(); - dfb_block(cfg, blk) <-- cfg_node(cfg), for blk in hugr.children(*cfg), if hugr.get_optype(blk).is_dataflow_block(); // Reachability relation bb_reachable(Node, Node); @@ -198,7 +196,10 @@ pub(super) fn run_datalog( // Where do the values "fed" along a control-flow edge come out? relation _cfg_succ_dest(Node, Node, Node); - _cfg_succ_dest(cfg, blk, inp) <-- dfb_block(cfg, blk), input_child(blk, inp); + _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), + for blk in hugr.children(*cfg), + if hugr.get_optype(blk).is_dataflow_block(), + input_child(blk, inp); _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = hugr.children(*cfg).nth(1); // Inputs of CFG propagate to entry block @@ -210,9 +211,8 @@ pub(super) fn run_datalog( // Outputs of each reachable block propagated to successor block or (if exit block) then CFG itself out_wire_value(dest, OutgoingPort::from(out_p), v) <-- - dfb_block(cfg, pred), bb_reachable(cfg, pred), - let df_block = hugr.get_optype(*pred).as_dataflow_block().unwrap(), + if let Some(df_block) = hugr.get_optype(*pred).as_dataflow_block(), for (succ_n, succ) in hugr.output_neighbours(*pred).enumerate(), output_child(pred, out_n), _cfg_succ_dest(cfg, succ, dest), From 3713ea7e7d5916bc99b71415ecb73c275d7a8469 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 16:35:27 +0100 Subject: [PATCH 154/281] relation doc --- hugr-passes/src/dataflow/datalog.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index f3f94da3f..d0620e634 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -194,7 +194,8 @@ pub(super) fn run_datalog( for (tag, bb) in hugr.output_neighbours(*pred).enumerate(), if predicate.supports_tag(tag); - // Where do the values "fed" along a control-flow edge come out? + // Relation: in `CFG` , values fed along a control-flow edge to + // come out of Value outports of . relation _cfg_succ_dest(Node, Node, Node); _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), for blk in hugr.children(*cfg), From d3178097d76ff1954507f70da17ecd4aa02da146 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 16:46:18 +0100 Subject: [PATCH 155/281] datalog docs (each relation), move _cfg_succ_dest --- hugr-passes/src/dataflow/datalog.rs | 53 +++++++++++++++-------------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index d0620e634..a2f22dbd6 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -67,15 +67,15 @@ pub(super) fn run_datalog( )] let all_results = ascent::ascent_run! { pub(super) struct AscentProgram; - relation node(Node); - relation in_wire(Node, IncomingPort); - relation out_wire(Node, OutgoingPort); - relation parent_of_node(Node, Node); - relation input_child(Node, Node); - relation output_child(Node, Node); - lattice out_wire_value(Node, OutgoingPort, PV); - lattice in_wire_value(Node, IncomingPort, PV); - lattice node_in_value_row(Node, ValueRow); + relation node(Node); // exists in the hugr + relation in_wire(Node, IncomingPort); // has an of `EdgeKind::Value` + relation out_wire(Node, OutgoingPort); // has an of `EdgeKind::Value` + relation parent_of_node(Node, Node); // is parent of + relation input_child(Node, Node); // has 1st child that is its `Input` + relation output_child(Node, Node); // has 2nd child that is its `Output` + lattice out_wire_value(Node, OutgoingPort, PV); // produces, on , the value + lattice in_wire_value(Node, IncomingPort, PV); // receives, on , the value + lattice node_in_value_row(Node, ValueRow); // 's inputs are node(n) <-- for n in hugr.nodes(); @@ -95,13 +95,14 @@ pub(super) fn run_datalog( if let Some((m, op)) = hugr.single_linked_output(*n, *ip), out_wire_value(m, op, v); - // We support prepopulating in_wire_value via in_wire_value_proto. + // Prepopulate in_wire_value from in_wire_value_proto. in_wire_value(n, p, PV::bottom()) <-- in_wire(n, p); in_wire_value(n, p, v) <-- for (n, p, v) in in_wire_value_proto.iter(), node(n), if let Some(sig) = hugr.signature(*n), if sig.input_ports().contains(p); + // Assemble in_value_row from in_value's node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = hugr.signature(*n); node_in_value_row(n, ValueRow::single_known(hugr.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); @@ -115,7 +116,7 @@ pub(super) fn run_datalog( for (p, v) in (0..).map(OutgoingPort::from).zip(outs); // DFG - relation dfg_node(Node); + relation dfg_node(Node); // is a `DFG` dfg_node(n) <-- node(n), if hugr.get_optype(*n).is_dfg(); out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), @@ -151,7 +152,8 @@ pub(super) fn run_datalog( for (out_p, v) in fields.enumerate(); // Conditional - relation conditional_node(Node); + relation conditional_node(Node); // is a `Conditional` + // is a `Conditional` and its 'th child (a `Case`) is : relation case_node(Node, usize, Node); conditional_node(n)<-- node(n), if hugr.get_optype(*n).is_conditional(); @@ -175,16 +177,17 @@ pub(super) fn run_datalog( output_child(case, o), in_wire_value(o, o_p, v); + // In `Conditional` , child `Case` is reachable given our knowledge of predicate relation case_reachable(Node, Node); case_reachable(cond, case) <-- case_node(cond, i, case), in_wire_value(cond, IncomingPort::from(0), v), if v.supports_tag(*i); // CFG - relation cfg_node(Node); + relation cfg_node(Node); // is a `CFG` cfg_node(n) <-- node(n), if hugr.get_optype(*n).is_cfg(); - // Reachability + // In `CFG` , basic block is reachable given our knowledge of predicates relation bb_reachable(Node, Node); bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = hugr.children(*cfg).next(); bb_reachable(cfg, bb) <-- cfg_node(cfg), @@ -194,15 +197,6 @@ pub(super) fn run_datalog( for (tag, bb) in hugr.output_neighbours(*pred).enumerate(), if predicate.supports_tag(tag); - // Relation: in `CFG` , values fed along a control-flow edge to - // come out of Value outports of . - relation _cfg_succ_dest(Node, Node, Node); - _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), - for blk in hugr.children(*cfg), - if hugr.get_optype(blk).is_dataflow_block(), - input_child(blk, inp); - _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = hugr.children(*cfg).nth(1); - // Inputs of CFG propagate to entry block out_wire_value(i_node, OutgoingPort::from(p.index()), v) <-- cfg_node(cfg), @@ -210,7 +204,16 @@ pub(super) fn run_datalog( input_child(entry, i_node), in_wire_value(cfg, p, v); - // Outputs of each reachable block propagated to successor block or (if exit block) then CFG itself + // In `CFG` , values fed along a control-flow edge to + // come out of Value outports of . + relation _cfg_succ_dest(Node, Node, Node); + _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = hugr.children(*cfg).nth(1); + _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), + for blk in hugr.children(*cfg), + if hugr.get_optype(blk).is_dataflow_block(), + input_child(blk, inp); + + // Outputs of each reachable block propagated to successor block or CFG itself out_wire_value(dest, OutgoingPort::from(out_p), v) <-- bb_reachable(cfg, pred), if let Some(df_block) = hugr.get_optype(*pred).as_dataflow_block(), @@ -222,7 +225,7 @@ pub(super) fn run_datalog( for (out_p, v) in fields.enumerate(); // Call - relation func_call(Node, Node); + relation func_call(Node, Node); // is a `Call` to `FuncDefn` func_call(call, func_defn) <-- node(call), if hugr.get_optype(*call).is_call(), From 69c3270a52a49abc3ab9fe8a4117e4c0ebae9bf5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 16:46:52 +0100 Subject: [PATCH 156/281] comment, use exactly_one --- hugr-passes/src/dataflow/datalog.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index a2f22dbd6..dc793d872 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -263,16 +263,15 @@ fn propagate_leaf_op( num_outs: usize, ) -> Option> { match hugr.get_optype(n) { - // Handle basics here. I guess (given the current interface) we could allow - // DFContext to handle these but at the least we'd want these impls to be - // easily available for reuse. + // Handle basics here. We could instead leave these to DFContext, + // but at least we'd want these impls to be easily reusable. op if op.cast::().is_some() => Some(ValueRow::from_iter([PV::new_variant( 0, ins.iter().cloned(), )])), op if op.cast::().is_some() => { let elem_tys = op.cast::().unwrap().0; - let [tup] = ins.iter().collect::>().try_into().unwrap(); + let tup = ins.iter().exactly_one().unwrap(); tup.variant_values(0, elem_tys.len()) .map(ValueRow::from_iter) } From dc56686b305f30e321100efd40fe903ef2766f4d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Oct 2024 12:02:00 +0100 Subject: [PATCH 157/281] Allow to handle LoadFunction --- hugr-passes/src/dataflow.rs | 7 +++++++ hugr-passes/src/dataflow/datalog.rs | 14 ++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 5cd0fa7de..d34c3454b 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -6,6 +6,7 @@ pub use datalog::Machine; mod value_row; mod results; +use hugr_core::types::TypeArg; pub use results::{AnalysisResults, TailLoopTermination}; mod partial_value; @@ -58,6 +59,12 @@ pub trait DFContext { fn value_from_const_hugr(&self, _node: Node, _fields: &[usize], _h: &Hugr) -> Option { None } + + /// Produces an abstract value from a [FuncDefn] or [FuncDecl] node, if possible. + /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. + fn value_from_function(&self, _node: Node, _type_args: &[TypeArg]) -> Option { + None + } } fn traverse_value( diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index dc793d872..da2aefb59 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -295,6 +295,20 @@ fn propagate_leaf_op( c.value_from_const(n, const_val), )) } + OpType::LoadFunction(load_op) => { + assert!(ins.is_empty()); // static edge + let func_node = hugr + .single_linked_output(n, load_op.function_port()) + .unwrap() + .0; + // Node could be a FuncDefn or a FuncDecl, so do not pass the node itself + Some(ValueRow::single_known( + 1, + 0, + c.value_from_function(func_node, &load_op.type_args) + .map_or(PV::Top, PV::Value), + )) + } OpType::ExtensionOp(e) => { // Interpret op. Default is we know nothing about the outputs (they still happen!) let mut outs = vec![PartialValue::Top; num_outs]; From 58f707aadc70250b1b0b0e90bff6b318cca02c75 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Oct 2024 12:28:07 +0100 Subject: [PATCH 158/281] TotalContext: don't get_optype, use the ExtensionOp via DataflowOpTrait --- hugr-passes/src/dataflow/total_context.rs | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index 986e3ff50..c11486309 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -1,7 +1,7 @@ use std::hash::Hash; -use hugr_core::ops::ExtensionOp; -use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; +use hugr_core::ops::{DataflowOpTrait, ExtensionOp}; +use hugr_core::{Hugr, IncomingPort, Node, OutgoingPort, PortIndex}; use super::partial_value::{AbstractValue, PartialValue, Sum}; use super::DFContext; @@ -27,14 +27,11 @@ impl> DFContext for T { fn interpret_leaf_op( &self, node: Node, - e: &ExtensionOp, + op: &ExtensionOp, ins: &[PartialValue], outs: &mut [PartialValue], ) { - let op = self.get_optype(node); - let Some(sig) = op.dataflow_signature() else { - return; - }; + let sig = op.signature(); let known_ins = sig .input_types() .iter() @@ -52,7 +49,7 @@ impl> DFContext for T { Some((IncomingPort::from(i), v)) }) .collect::>(); - for (p, v) in self.interpret_leaf_op(node, e, &known_ins) { + for (p, v) in self.interpret_leaf_op(node, op, &known_ins) { outs[p.index()] = v; } } From 05664817641691430e9859de4c9719cdda6bdc90 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Oct 2024 12:35:09 +0100 Subject: [PATCH 159/281] TotalContext: do not require Eq, Hash, etc. --- hugr-passes/src/dataflow/total_context.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index c11486309..d775b76b1 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -1,7 +1,5 @@ -use std::hash::Hash; - use hugr_core::ops::{DataflowOpTrait, ExtensionOp}; -use hugr_core::{Hugr, IncomingPort, Node, OutgoingPort, PortIndex}; +use hugr_core::{IncomingPort, Node, OutgoingPort, PortIndex}; use super::partial_value::{AbstractValue, PartialValue, Sum}; use super::DFContext; @@ -9,7 +7,7 @@ use super::DFContext; /// A simpler interface like [DFContext] but where the context only cares about /// values that are completely known (in the lattice `V`) rather than partially /// (e.g. no [PartialSum]s of more than one variant, no top/bottom) -pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { +pub trait TotalContext { /// Representation of a (single, non-partial) value usable for interpretation type InterpretableVal: From + TryFrom>; From a4977238cc63b92d26315c7b47ffc6d60dcc142a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Oct 2024 12:36:59 +0100 Subject: [PATCH 160/281] HugrValueContext -> ConstFoldContext, empty, no need for Eq/Hash/etc. --- hugr-passes/src/const_fold2.rs | 2 +- hugr-passes/src/const_fold2/context.rs | 63 ++------------------------ 2 files changed, 6 insertions(+), 59 deletions(-) diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index 93b772d88..3801ff134 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -4,4 +4,4 @@ // These are pub because this "example" is used for testing the framework. mod context; pub mod value_handle; -pub use context::HugrValueContext; +pub use context::ConstFoldContext; diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs index 886c8da71..c2b5619e5 100644 --- a/hugr-passes/src/const_fold2/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -1,71 +1,18 @@ -use std::hash::{Hash, Hasher}; -use std::ops::Deref; -use std::sync::Arc; - use hugr_core::ops::{ExtensionOp, Value}; -use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; +use hugr_core::{IncomingPort, Node, OutgoingPort}; use super::value_handle::{ValueHandle, ValueKey}; use crate::dataflow::{PartialValue, TotalContext}; -/// A [context](crate::dataflow::DFContext) for doing analysis with [ValueHandle]s. -/// Interprets [LoadConstant](OpType::LoadConstant) nodes, -/// and [ExtensionOp](OpType::ExtensionOp) nodes where the extension does -/// (using [Value]s for extension-op inputs). +/// A [context](crate::dataflow::DFContext) that uses [ValueHandle]s +/// and performs [ExtensionOp::constant_fold] (using [Value]s for extension-op inputs). /// /// Just stores a Hugr (actually any [HugrView]), /// (there is )no state for operation-interpretation. #[derive(Debug)] -pub struct HugrValueContext(Arc); - -impl HugrValueContext { - /// Creates a new instance, given ownership of the [HugrView] - pub fn new(hugr: H) -> Self { - Self(Arc::new(hugr)) - } -} - -// Deriving Clone requires H:HugrView to implement Clone, -// but we don't need that as we only clone the Arc. -impl Clone for HugrValueContext { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} - -// Any value used in an Ascent program must be hashable. -// However, there should only be one DFContext, so its hash is immaterial. -impl Hash for HugrValueContext { - fn hash(&self, _state: &mut I) {} -} - -impl PartialEq for HugrValueContext { - fn eq(&self, other: &Self) -> bool { - // Any AscentProgram should have only one DFContext (maybe cloned) - assert!(Arc::ptr_eq(&self.0, &other.0)); - true - } -} - -impl Eq for HugrValueContext {} - -impl PartialOrd for HugrValueContext { - fn partial_cmp(&self, other: &Self) -> Option { - // Any AscentProgram should have only one DFContext (maybe cloned) - assert!(Arc::ptr_eq(&self.0, &other.0)); - Some(std::cmp::Ordering::Equal) - } -} - -impl Deref for HugrValueContext { - type Target = Hugr; - - fn deref(&self) -> &Self::Target { - self.0.base_hugr() - } -} +pub struct ConstFoldContext; -impl TotalContext for HugrValueContext { +impl TotalContext for ConstFoldContext { type InterpretableVal = Value; fn interpret_leaf_op( From 33a85923ae9f03bb43fdfe1b25ea8ead69ce9e92 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Oct 2024 14:07:05 +0100 Subject: [PATCH 161/281] doc --- hugr-passes/src/dataflow.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index d34c3454b..8510f8f9b 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -60,8 +60,13 @@ pub trait DFContext { None } - /// Produces an abstract value from a [FuncDefn] or [FuncDecl] node, if possible. + /// Produces an abstract value from a [FuncDefn] or [FuncDecl] node (that has been loaded + /// via a [LoadFunction]), if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. + /// + /// [FuncDefn]: hugr_core::ops::FuncDefn + /// [FuncDecl]: hugr_core::ops::FuncDecl + /// [LoadFunction]: hugr_core::ops::LoadFunction fn value_from_function(&self, _node: Node, _type_args: &[TypeArg]) -> Option { None } From b153ada208ee3a47fde5b2ebe3ba296c184e0238 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Oct 2024 14:16:33 +0100 Subject: [PATCH 162/281] Separate out ConstLoader --- hugr-passes/src/dataflow.rs | 9 +++++++-- hugr-passes/src/dataflow/test.rs | 3 ++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 8510f8f9b..c764a2ea0 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -17,7 +17,7 @@ use hugr_core::{Hugr, Node}; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). -pub trait DFContext { +pub trait DFContext: ConstLoader { /// Given lattice values for each input, update lattice values for the (dataflow) outputs. /// For extension ops only, excluding [MakeTuple] and [UnpackTuple]. /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] @@ -34,7 +34,12 @@ pub trait DFContext { _outs: &mut [PartialValue], ) { } +} +/// Trait for loading [PartialValue]s from constants in a Hugr. The default +/// traverses [Sum](Value::Sum) constants to their non-Sum leaves but represents +/// each leaf as [PartialValue::Top]. +pub trait ConstLoader { /// Produces an abstract value from a constant. The default impl /// traverses the constant [Value] to its leaves ([Value::Extension] and [Value::Function]), /// converts these using [Self::value_from_custom_const] and [Self::value_from_const_hugr], @@ -73,7 +78,7 @@ pub trait DFContext { } fn traverse_value( - s: &(impl DFContext + ?Sized), + s: &(impl ConstLoader + ?Sized), n: Node, fields: &mut Vec, cst: &Value, diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index fe4af1c46..d6721d8ed 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -15,7 +15,7 @@ use hugr_core::{ use hugr_core::{Hugr, Wire}; use rstest::{fixture, rstest}; -use super::{AbstractValue, DFContext, Machine, PartialValue, TailLoopTermination}; +use super::{AbstractValue, ConstLoader, DFContext, Machine, PartialValue, TailLoopTermination}; // ------- Minimal implementation of DFContext and AbstractValue ------- #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -25,6 +25,7 @@ impl AbstractValue for Void {} struct TestContext; +impl ConstLoader for TestContext {} impl DFContext for TestContext {} // This allows testing creation of tuple/sum Values (only) From 5052ac0268cffcec9c0924edcf28b095314a67be Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Oct 2024 16:10:05 +0100 Subject: [PATCH 163/281] value_from_(custom_const=>opaque), taking &OpaqueValue --- hugr-passes/src/dataflow.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index c764a2ea0..23db00383 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -6,14 +6,15 @@ pub use datalog::Machine; mod value_row; mod results; -use hugr_core::types::TypeArg; pub use results::{AnalysisResults, TailLoopTermination}; mod partial_value; pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; -use hugr_core::ops::{constant::CustomConst, ExtensionOp, Value}; use hugr_core::{Hugr, Node}; +use hugr_core::ops::{ExtensionOp, Value}; +use hugr_core::ops::constant::OpaqueValue; +use hugr_core::types::TypeArg; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). @@ -48,13 +49,13 @@ pub trait ConstLoader { traverse_value(self, n, &mut Vec::new(), cst) } - /// Produces an abstract value from a [CustomConst], if possible. + /// Produces an abstract value from an [OpaqueValue], if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. - fn value_from_custom_const( + fn value_from_opaque( &self, _node: Node, _fields: &[usize], - _cc: &dyn CustomConst, + _val: &OpaqueValue, ) -> Option { None } @@ -94,7 +95,7 @@ fn traverse_value( PartialValue::new_variant(*tag, elems) } Value::Extension { e } => s - .value_from_custom_const(n, fields, e.value()) + .value_from_opaque(n, fields, e) .map(PartialValue::from) .unwrap_or(PartialValue::Top), Value::Function { hugr } => s From d14cc69097847bf0389f310e477f6a70c8fa2a05 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Oct 2024 18:15:46 +0100 Subject: [PATCH 164/281] TotalContext requires ConstLoader --- hugr-passes/src/dataflow/total_context.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index d775b76b1..d3c00b0c1 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -2,7 +2,7 @@ use hugr_core::ops::{DataflowOpTrait, ExtensionOp}; use hugr_core::{IncomingPort, Node, OutgoingPort, PortIndex}; use super::partial_value::{AbstractValue, PartialValue, Sum}; -use super::DFContext; +use super::{ConstLoader, DFContext}; /// A simpler interface like [DFContext] but where the context only cares about /// values that are completely known (in the lattice `V`) rather than partially @@ -21,7 +21,7 @@ pub trait TotalContext { ) -> Vec<(OutgoingPort, PartialValue)>; } -impl> DFContext for T { +impl + ConstLoader> DFContext for T { fn interpret_leaf_op( &self, node: Node, From 56b10f586c07351b8e5180fa08ececd1a811e300 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Oct 2024 17:09:31 +0100 Subject: [PATCH 165/281] WIP implement ConstLoader, munge ValueHandle construction --- hugr-core/src/ops/constant.rs | 6 +++ hugr-passes/src/const_fold2/context.rs | 28 ++++++++++- hugr-passes/src/const_fold2/value_handle.rs | 52 ++++++++------------- 3 files changed, 52 insertions(+), 34 deletions(-) diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 5b9309407..b267ba777 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -301,6 +301,12 @@ impl From for OpaqueValue { } } +impl From for Box { + fn from(value: OpaqueValue) -> Self { + value.v + } +} + impl PartialEq for OpaqueValue { fn eq(&self, other: &Self) -> bool { self.value().equal_consts(other.value()) diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs index c2b5619e5..93eefc8c0 100644 --- a/hugr-passes/src/const_fold2/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -1,8 +1,12 @@ +use std::sync::Arc; + +use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, Value}; use hugr_core::{IncomingPort, Node, OutgoingPort}; +use itertools::Either; -use super::value_handle::{ValueHandle, ValueKey}; -use crate::dataflow::{PartialValue, TotalContext}; +use super::value_handle::{HashedConst, ValueHandle, ValueKey}; +use crate::dataflow::{ConstLoader, PartialValue, TotalContext}; /// A [context](crate::dataflow::DFContext) that uses [ValueHandle]s /// and performs [ExtensionOp::constant_fold] (using [Value]s for extension-op inputs). @@ -12,6 +16,26 @@ use crate::dataflow::{PartialValue, TotalContext}; #[derive(Debug)] pub struct ConstFoldContext; +impl ConstLoader for ConstFoldContext { + fn value_from_opaque( + &self, + node: Node, + fields: &[usize], + val: &OpaqueValue + ) -> Option { + Some(ValueHandle::new_opaque(node, fields, val.clone())) + } + + fn value_from_const_hugr(&self, node: Node, fields: &[usize], h: &hugr_core::Hugr) -> Option { + Some(ValueHandle(fields.iter().fold(ValueKey::Node(node), |k,i|k.field(*i)), Either::Right(Arc::new(h.clone())))) + } + + fn value_from_function(&self, _node: Node, _type_args: &[hugr_core::types::TypeArg]) -> Option { + + } +} + + impl TotalContext for ConstFoldContext { type InterpretableVal = Value; diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index 05d57bfaf..c13bb02d4 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -13,7 +13,19 @@ use crate::dataflow::{AbstractValue, PartialValue}; #[derive(Clone, Debug)] pub struct HashedConst { hash: u64, - val: Arc, + pub(super) val: Arc, +} + +impl HashedConst { + pub(super) fn try_new(val: Arc) -> Option { + let mut hasher = DefaultHasher::new(); + val.value().try_hash(&mut hasher).then(|| { + HashedConst { + hash: hasher.finish(), + val + } + }) + } } impl PartialEq for HashedConst { @@ -50,47 +62,23 @@ impl From for ValueKey { } impl ValueKey { - pub fn new(n: Node, k: impl CustomConst) -> Self { - Self::try_new(k).unwrap_or(Self::Node(n)) - } - - pub fn try_new(cst: impl CustomConst) -> Option { - let mut hasher = DefaultHasher::new(); - cst.try_hash(&mut hasher).then(|| { - Self::Const(HashedConst { - hash: hasher.finish(), - val: Arc::new(cst), - }) - }) - } - - fn field(self, i: usize) -> Self { + pub(super) fn field(self, i: usize) -> Self { Self::Field(i, Box::new(self)) } } #[derive(Clone, Debug)] -pub struct ValueHandle(ValueKey, Arc>>); +pub struct ValueHandle(pub(super) ValueKey, pub(super) Either, Arc>); impl ValueHandle { - pub fn new(key: ValueKey, value: Value) -> PartialValue { - match value { - Value::Extension { e } => PartialValue::Value(Self(key, Arc::new(Either::Left(e)))), - Value::Function { hugr } => { - PartialValue::Value(Self(key, Arc::new(Either::Right(hugr)))) - } - Value::Sum(sum) => PartialValue::new_variant( - sum.tag, - sum.values - .into_iter() - .enumerate() - .map(|(i, v)| Self::new(key.clone().field(i), v)), - ), - } + pub fn new_opaque(node: Node, fields: &[usize], val: OpaqueValue) -> Self { + let arc: Arc = Box::::from(val).into(); + let key = HashedConst::try_new(arc.clone()).map_or(ValueKey::Node(node), ValueKey::Const); + Self(fields.iter().fold(key, |k,i| k.field(*i)), Either::Left(arc)) } pub fn get_type(&self) -> Type { - match &*self.1 { + match &self.1 { Either::Left(e) => e.get_type(), Either::Right(bh) => Type::new_function(bh.inner_function_type().unwrap()), } From dff52bd7def427928e294d7c729afaaf340e5760 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Oct 2024 18:13:38 +0100 Subject: [PATCH 166/281] Restructure ValueHandle: no keys for HashedConsts, add functions. --- hugr-passes/src/const_fold2/context.rs | 40 +++-- hugr-passes/src/const_fold2/value_handle.rs | 153 ++++++++++++-------- hugr-passes/src/dataflow.rs | 11 +- 3 files changed, 120 insertions(+), 84 deletions(-) diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs index 93eefc8c0..921ecfe19 100644 --- a/hugr-passes/src/const_fold2/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -1,11 +1,8 @@ -use std::sync::Arc; - use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, Value}; use hugr_core::{IncomingPort, Node, OutgoingPort}; -use itertools::Either; -use super::value_handle::{HashedConst, ValueHandle, ValueKey}; +use super::value_handle::ValueHandle; use crate::dataflow::{ConstLoader, PartialValue, TotalContext}; /// A [context](crate::dataflow::DFContext) that uses [ValueHandle]s @@ -18,24 +15,39 @@ pub struct ConstFoldContext; impl ConstLoader for ConstFoldContext { fn value_from_opaque( - &self, - node: Node, - fields: &[usize], - val: &OpaqueValue + &self, + node: Node, + fields: &[usize], + val: &OpaqueValue, ) -> Option { Some(ValueHandle::new_opaque(node, fields, val.clone())) } - fn value_from_const_hugr(&self, node: Node, fields: &[usize], h: &hugr_core::Hugr) -> Option { - Some(ValueHandle(fields.iter().fold(ValueKey::Node(node), |k,i|k.field(*i)), Either::Right(Arc::new(h.clone())))) + fn value_from_const_hugr( + &self, + node: Node, + fields: &[usize], + h: &hugr_core::Hugr, + ) -> Option { + Some(ValueHandle::new_const_hugr( + node, + fields, + Box::new(h.clone()), + )) } - fn value_from_function(&self, _node: Node, _type_args: &[hugr_core::types::TypeArg]) -> Option { - + fn value_from_function( + &self, + node: Node, + type_args: &[hugr_core::types::TypeArg], + ) -> Option { + Some(ValueHandle::new_function( + node, + type_args.into_iter().cloned(), + )) } } - impl TotalContext for ConstFoldContext { type InterpretableVal = Value; @@ -48,7 +60,7 @@ impl TotalContext for ConstFoldContext { let ins = ins.iter().map(|(p, v)| (*p, v.clone())).collect::>(); op.constant_fold(&ins).map_or(Vec::new(), |outs| { outs.into_iter() - .map(|(p, v)| (p, ValueHandle::new(ValueKey::Node(n), v))) + .map(|(p, v)| (p, self.value_from_const(n, &v))) .collect() }) } diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index c13bb02d4..f88ce85d5 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -2,35 +2,33 @@ use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1 use std::hash::{Hash, Hasher}; use std::sync::Arc; -use hugr_core::ops::constant::{CustomConst, OpaqueValue}; +use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::Value; -use hugr_core::types::Type; -use hugr_core::{Hugr, HugrView, Node}; +use hugr_core::types::TypeArg; +use hugr_core::{Hugr, Node}; use itertools::Either; -use crate::dataflow::{AbstractValue, PartialValue}; +use crate::dataflow::AbstractValue; #[derive(Clone, Debug)] pub struct HashedConst { hash: u64, - pub(super) val: Arc, + pub(super) val: Arc, } impl HashedConst { - pub(super) fn try_new(val: Arc) -> Option { + pub(super) fn try_new(val: Arc) -> Option { let mut hasher = DefaultHasher::new(); - val.value().try_hash(&mut hasher).then(|| { - HashedConst { - hash: hasher.finish(), - val - } + val.value().try_hash(&mut hasher).then(|| HashedConst { + hash: hasher.finish(), + val, }) } } impl PartialEq for HashedConst { fn eq(&self, other: &Self) -> bool { - self.hash == other.hash && self.val.equal_consts(other.val.as_ref()) + self.hash == other.hash && self.val.value().equal_consts(other.val.value()) } } @@ -43,45 +41,41 @@ impl Hash for HashedConst { } #[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum ValueKey { - Field(usize, Box), - Const(HashedConst), +enum NodePart { + Field(usize, Box), Node(Node), } -impl From for ValueKey { - fn from(n: Node) -> Self { - Self::Node(n) - } -} - -impl From for ValueKey { - fn from(value: HashedConst) -> Self { - Self::Const(value) - } -} - -impl ValueKey { - pub(super) fn field(self, i: usize) -> Self { - Self::Field(i, Box::new(self)) +impl NodePart { + fn new(node: Node, fields: &[usize]) -> Self { + fields + .iter() + .fold(Self::Node(node), |k, i| Self::Field(*i, Box::new(k))) } } #[derive(Clone, Debug)] -pub struct ValueHandle(pub(super) ValueKey, pub(super) Either, Arc>); +pub enum ValueHandle { + Hashable(HashedConst), + Unhashable(NodePart, Either, Arc>), + Function(Node, Vec), +} impl ValueHandle { pub fn new_opaque(node: Node, fields: &[usize], val: OpaqueValue) -> Self { - let arc: Arc = Box::::from(val).into(); - let key = HashedConst::try_new(arc.clone()).map_or(ValueKey::Node(node), ValueKey::Const); - Self(fields.iter().fold(key, |k,i| k.field(*i)), Either::Left(arc)) + let arc = Arc::new(val); + HashedConst::try_new(arc.clone()).map_or( + Self::Unhashable(NodePart::new(node, fields), Either::Left(arc)), + Self::Hashable, + ) } - pub fn get_type(&self) -> Type { - match &self.1 { - Either::Left(e) => e.get_type(), - Either::Right(bh) => Type::new_function(bh.inner_function_type().unwrap()), - } + pub fn new_const_hugr(node: Node, fields: &[usize], val: Box) -> Self { + Self::Unhashable(NodePart::new(node, fields), Either::Right(Arc::from(val))) + } + + pub fn new_function(node: Node, type_args: impl IntoIterator) -> Self { + Self::Function(node, type_args.into_iter().collect()) } } @@ -89,16 +83,19 @@ impl AbstractValue for ValueHandle {} impl PartialEq for ValueHandle { fn eq(&self, other: &Self) -> bool { - // If the keys are equal, we return true since the values must have the - // same provenance, and so be equal. If the keys are different but the - // values are equal, we could return true if we didn't impl Eq, but - // since we do impl Eq, the Hash contract prohibits us from having equal - // values with different hashes. - let r = self.0 == other.0; - if r { - debug_assert_eq!(self.get_type(), other.get_type()); + match (self, other) { + (Self::Function(n1, args1), Self::Function(n2, args2)) => n1 == n2 && args1 == args2, + (Self::Hashable(h1), Self::Hashable(h2)) => h1 == h2, + (Self::Unhashable(k1, _), Self::Unhashable(k2, _)) => { + // If the keys are equal, we return true since the values must have the + // same provenance, and so be equal. If the keys are different but the + // values are equal, we could return true if we didn't impl Eq, but + // since we do impl Eq, the Hash contract prohibits us from having equal + // values with different hashes. + k1 == k2 + } + _ => false, } - r } } @@ -106,16 +103,37 @@ impl Eq for ValueHandle {} impl Hash for ValueHandle { fn hash(&self, state: &mut I) { - self.0.hash(state); + match self { + ValueHandle::Hashable(hc) => hc.hash(state), + ValueHandle::Unhashable(key, _) => key.hash(state), + ValueHandle::Function(node, vec) => { + node.hash(state); + vec.hash(state); + } + } } } -impl From for Value { - fn from(value: ValueHandle) -> Self { - match Arc::>::unwrap_or_clone(value.1) { - Either::Left(e) => Value::Extension { e }, - Either::Right(hugr) => Value::Function { hugr }, - } +// Unfortunately we need From for Value to be able to pass +// Value's into interpret_leaf_op. So that probably doesn't make sense... +impl TryFrom for Value { + type Error = String; + fn try_from(value: ValueHandle) -> Result { + Ok(match value { + ValueHandle::Hashable(HashedConst { val, .. }) + | ValueHandle::Unhashable(_, Either::Left(val)) => Value::Extension { + e: Arc::unwrap_or_clone(val), + }, + ValueHandle::Unhashable(_, Either::Right(hugr)) => { + Value::function(Arc::unwrap_or_clone(hugr)).map_err(|e| e.to_string())? + } + ValueHandle::Function(node, _type_args) => { + return Err(format!( + "Function defined externally ({}) cannot be turned into Value", + node + )) + } + }) } } @@ -141,18 +159,29 @@ mod test { fn value_key_eq() { let n = Node::from(portgraph::NodeIndex::new(0)); let n2: Node = portgraph::NodeIndex::new(1).into(); - let k1 = ValueKey::new(n, ConstString::new("foo".to_string())); - let k2 = ValueKey::new(n2, ConstString::new("foo".to_string())); - let k3 = ValueKey::new(n, ConstString::new("bar".to_string())); + let k1 = ValueHandle::new_opaque(n, &[], ConstString::new("foo".to_string()).into()); + let k2 = ValueHandle::new_opaque(n2, &[], ConstString::new("foo".to_string()).into()); + let k3 = ValueHandle::new_opaque(n, &[], ConstString::new("bar".to_string()).into()); - assert_eq!(k1, k2); // Node ignored + assert_eq!(k1, k2); // Node ignored as constant is hashable assert_ne!(k1, k3); - assert_eq!(ValueKey::from(n), ValueKey::from(n)); + // Hashable vs Unhashable is not equal (even with same key): let f = ConstF64::new(std::f64::consts::PI); - assert_eq!(ValueKey::new(n, f.clone()), ValueKey::from(n)); + assert_ne!(ValueHandle::new_opaque(n, &[], f.into()), k1); + assert_ne!(k1, ValueHandle::new_opaque(n, &[], f.into())); + + // Unhashable vals are compared only by key, not content + let f2 = ConstF64::new(std::f64::consts::E); + assert_eq!( + ValueKey::new_opaque(n, &[], f), + ValueKey::new_opaque(n, &[], f2) + ); + assert_ne!( + ValueKey::new_opaque(n, &[], f), + ValueKey::new_opaque(n2, &[], f) + ); - assert_ne!(ValueKey::new(n, f.clone()), ValueKey::new(n2, f)); // Node taken into account let k4 = ValueKey::from(n); let k5 = ValueKey::from(n); let k6: ValueKey = ValueKey::from(n2); diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 188cd4707..7b668b164 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -11,10 +11,10 @@ pub use results::{AnalysisResults, TailLoopTermination}; mod partial_value; pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; -use hugr_core::{Hugr, Node}; -use hugr_core::ops::{ExtensionOp, Value}; use hugr_core::ops::constant::OpaqueValue; +use hugr_core::ops::{ExtensionOp, Value}; use hugr_core::types::TypeArg; +use hugr_core::{Hugr, Node}; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). @@ -51,12 +51,7 @@ pub trait ConstLoader { /// Produces an abstract value from an [OpaqueValue], if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. - fn value_from_opaque( - &self, - _node: Node, - _fields: &[usize], - _val: &OpaqueValue, - ) -> Option { + fn value_from_opaque(&self, _node: Node, _fields: &[usize], _val: &OpaqueValue) -> Option { None } From af651e07ebb26ff3af45210c74546496b2507b4d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 25 Oct 2024 16:20:00 +0100 Subject: [PATCH 167/281] (TEMP / to undo later?) Remove ValueHandle::Function --- hugr-passes/src/const_fold2/context.rs | 11 +---------- hugr-passes/src/const_fold2/value_handle.rs | 7 ------- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs index 921ecfe19..34a1ff84c 100644 --- a/hugr-passes/src/const_fold2/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -36,16 +36,7 @@ impl ConstLoader for ConstFoldContext { )) } - fn value_from_function( - &self, - node: Node, - type_args: &[hugr_core::types::TypeArg], - ) -> Option { - Some(ValueHandle::new_function( - node, - type_args.into_iter().cloned(), - )) - } + // Do not handle (Load)Function/value_from_function yet. } impl TotalContext for ConstFoldContext { diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index f88ce85d5..1fe272bf1 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -4,7 +4,6 @@ use std::sync::Arc; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::Value; -use hugr_core::types::TypeArg; use hugr_core::{Hugr, Node}; use itertools::Either; @@ -58,7 +57,6 @@ impl NodePart { pub enum ValueHandle { Hashable(HashedConst), Unhashable(NodePart, Either, Arc>), - Function(Node, Vec), } impl ValueHandle { @@ -73,10 +71,6 @@ impl ValueHandle { pub fn new_const_hugr(node: Node, fields: &[usize], val: Box) -> Self { Self::Unhashable(NodePart::new(node, fields), Either::Right(Arc::from(val))) } - - pub fn new_function(node: Node, type_args: impl IntoIterator) -> Self { - Self::Function(node, type_args.into_iter().collect()) - } } impl AbstractValue for ValueHandle {} @@ -84,7 +78,6 @@ impl AbstractValue for ValueHandle {} impl PartialEq for ValueHandle { fn eq(&self, other: &Self) -> bool { match (self, other) { - (Self::Function(n1, args1), Self::Function(n2, args2)) => n1 == n2 && args1 == args2, (Self::Hashable(h1), Self::Hashable(h2)) => h1 == h2, (Self::Unhashable(k1, _), Self::Unhashable(k2, _)) => { // If the keys are equal, we return true since the values must have the From 29d7e035180f932e5d38d676df0b8c28b80d1da6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 25 Oct 2024 16:20:26 +0100 Subject: [PATCH 168/281] Tidy / fix tests --- hugr-passes/src/const_fold2/context.rs | 45 ++++++++ hugr-passes/src/const_fold2/value_handle.rs | 116 +++++++------------- 2 files changed, 82 insertions(+), 79 deletions(-) diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs index 34a1ff84c..40aec93e5 100644 --- a/hugr-passes/src/const_fold2/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -56,3 +56,48 @@ impl TotalContext for ConstFoldContext { }) } } + +#[cfg(test)] +mod test { + use hugr_core::ops::{constant::CustomConst, Value}; + use hugr_core::std_extensions::arithmetic::{float_types::ConstF64, int_types::ConstInt}; + use hugr_core::{types::SumType, Node}; + use itertools::Itertools; + use rstest::rstest; + + use crate::{ + const_fold2::ConstFoldContext, + dataflow::{ConstLoader, PartialValue}, + }; + + #[rstest] + #[case(ConstInt::new_u(4, 2).unwrap(), true)] + #[case(ConstF64::new(std::f64::consts::PI), false)] + fn value_handling(#[case] k: impl CustomConst + Clone, #[case] eq: bool) { + let n = Node::from(portgraph::NodeIndex::new(7)); + let st = SumType::new([vec![k.get_type()], vec![]]); + let subject_val = Value::sum(0, [k.clone().into()], st).unwrap(); + let v1 = ConstFoldContext.value_from_const(n, &subject_val); + + let v1_subfield = { + let PartialValue::PartialSum(ps1) = v1 else { + panic!() + }; + ps1.0 + .into_iter() + .exactly_one() + .unwrap() + .1 + .into_iter() + .exactly_one() + .unwrap() + }; + + let v2 = ConstFoldContext.value_from_const(n, &k.into()); + if eq { + assert_eq!(v1_subfield, v2); + } else { + assert_ne!(v1_subfield, v2); + } + } +} diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index 1fe272bf1..6c8724c14 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -40,7 +40,7 @@ impl Hash for HashedConst { } #[derive(Clone, Debug, PartialEq, Eq, Hash)] -enum NodePart { +pub enum NodePart { Field(usize, Box), Node(Node), } @@ -99,42 +99,33 @@ impl Hash for ValueHandle { match self { ValueHandle::Hashable(hc) => hc.hash(state), ValueHandle::Unhashable(key, _) => key.hash(state), - ValueHandle::Function(node, vec) => { - node.hash(state); - vec.hash(state); - } } } } // Unfortunately we need From for Value to be able to pass // Value's into interpret_leaf_op. So that probably doesn't make sense... -impl TryFrom for Value { - type Error = String; - fn try_from(value: ValueHandle) -> Result { - Ok(match value { +impl From for Value { + fn from(value: ValueHandle) -> Self { + match value { ValueHandle::Hashable(HashedConst { val, .. }) | ValueHandle::Unhashable(_, Either::Left(val)) => Value::Extension { e: Arc::unwrap_or_clone(val), }, ValueHandle::Unhashable(_, Either::Right(hugr)) => { - Value::function(Arc::unwrap_or_clone(hugr)).map_err(|e| e.to_string())? - } - ValueHandle::Function(node, _type_args) => { - return Err(format!( - "Function defined externally ({}) cannot be turned into Value", - node - )) + Value::function(Arc::unwrap_or_clone(hugr)) + .map_err(|e| e.to_string()) + .unwrap() } - }) + } } } #[cfg(test)] mod test { use hugr_core::{ - extension::prelude::ConstString, - ops::constant::CustomConst as _, + builder::{endo_sig, DFGBuilder, Dataflow, DataflowHugr}, + extension::prelude::{ConstString, USIZE_T}, std_extensions::{ arithmetic::{ float_types::{ConstF64, FLOAT64_TYPE}, @@ -142,9 +133,7 @@ mod test { }, collections::ListValue, }, - types::SumType, }; - use itertools::Itertools; use super::*; @@ -152,44 +141,37 @@ mod test { fn value_key_eq() { let n = Node::from(portgraph::NodeIndex::new(0)); let n2: Node = portgraph::NodeIndex::new(1).into(); - let k1 = ValueHandle::new_opaque(n, &[], ConstString::new("foo".to_string()).into()); - let k2 = ValueHandle::new_opaque(n2, &[], ConstString::new("foo".to_string()).into()); - let k3 = ValueHandle::new_opaque(n, &[], ConstString::new("bar".to_string()).into()); + let h1 = ValueHandle::new_opaque(n, &[], ConstString::new("foo".to_string()).into()); + let h2 = ValueHandle::new_opaque(n2, &[], ConstString::new("foo".to_string()).into()); + let h3 = ValueHandle::new_opaque(n, &[], ConstString::new("bar".to_string()).into()); - assert_eq!(k1, k2); // Node ignored as constant is hashable - assert_ne!(k1, k3); + assert_eq!(h1, h2); // Node ignored as constant is hashable + assert_ne!(h1, h3); // Hashable vs Unhashable is not equal (even with same key): let f = ConstF64::new(std::f64::consts::PI); - assert_ne!(ValueHandle::new_opaque(n, &[], f.into()), k1); - assert_ne!(k1, ValueHandle::new_opaque(n, &[], f.into())); + let h4 = ValueHandle::new_opaque(n, &[], f.clone().into()); + assert_ne!(h4, h1); + assert_ne!(h1, h4); // Unhashable vals are compared only by key, not content let f2 = ConstF64::new(std::f64::consts::E); + assert_eq!(h4, ValueHandle::new_opaque(n, &[], f2.clone().into())); + assert_ne!(h4, ValueHandle::new_opaque(n, &[5], f2.into())); + + let h = Box::new(make_hugr(1)); + let h5 = ValueHandle::new_const_hugr(n, &[], h.clone()); assert_eq!( - ValueKey::new_opaque(n, &[], f), - ValueKey::new_opaque(n, &[], f2) + h5, + ValueHandle::new_const_hugr(n, &[], Box::new(make_hugr(2))) ); - assert_ne!( - ValueKey::new_opaque(n, &[], f), - ValueKey::new_opaque(n2, &[], f) - ); - - let k4 = ValueKey::from(n); - let k5 = ValueKey::from(n); - let k6: ValueKey = ValueKey::from(n2); - - assert_eq!(&k4, &k5); - assert_ne!(&k4, &k6); - - let k7 = k5.clone().field(3); - let k4 = k4.field(3); - - assert_eq!(&k4, &k7); - - let k5 = k5.field(2); + assert_ne!(h5, ValueHandle::new_const_hugr(n2, &[], h)); + } - assert_ne!(&k5, &k7); + fn make_hugr(num_wires: usize) -> Hugr { + let d = DFGBuilder::new(endo_sig(vec![USIZE_T; num_wires])).unwrap(); + let inputs = d.input_wires(); + d.finish_prelude_hugr_with_outputs(inputs).unwrap() } #[test] @@ -199,41 +181,17 @@ mod test { let v3 = ConstF64::new(std::f64::consts::PI); let n = Node::from(portgraph::NodeIndex::new(0)); - let n2: Node = portgraph::NodeIndex::new(1).into(); let lst = ListValue::new(INT_TYPES[0].clone(), [v1.into(), v2.into()]); - assert_eq!(ValueKey::new(n, lst.clone()), ValueKey::new(n2, lst)); + assert_eq!( + ValueHandle::new_opaque(n, &[], lst.clone().into()), + ValueHandle::new_opaque(n, &[1], lst.into()) + ); let lst = ListValue::new(FLOAT64_TYPE, [v3.into()]); assert_ne!( - ValueKey::new(n, lst.clone()), - ValueKey::new(n2, lst.clone()) + ValueHandle::new_opaque(n, &[], lst.clone().into()), + ValueHandle::new_opaque(n, &[3], lst.into()) ); } - - #[test] - fn value_handle_eq() { - let k_i = ConstInt::new_u(4, 2).unwrap(); - let st = SumType::new([vec![k_i.get_type()], vec![]]); - let subject_val = Value::sum(0, [k_i.clone().into()], st).unwrap(); - - let k1 = ValueKey::try_new(ConstString::new("foo".to_string())).unwrap(); - let PartialValue::PartialSum(ps1) = ValueHandle::new(k1.clone(), subject_val.clone()) - else { - panic!() - }; - let (_tag, fields) = ps1.0.into_iter().exactly_one().unwrap(); - let PartialValue::Value(vh1) = fields.into_iter().exactly_one().unwrap() else { - panic!() - }; - - let PartialValue::Value(v2) = ValueHandle::new(k1.clone(), Value::extension(k_i).into()) - else { - panic!() - }; - - // we do not compare the value, just the key - assert_ne!(vh1, v2); - assert_eq!(vh1.1, v2.1); - } } From 8b5f6b009f89910c5444316e9437cf40dcc6416e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sat, 26 Oct 2024 14:30:17 +0100 Subject: [PATCH 169/281] note about keying by outgoingport --- hugr-passes/src/const_fold2/context.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs index 40aec93e5..c9703ce31 100644 --- a/hugr-passes/src/const_fold2/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -51,7 +51,12 @@ impl TotalContext for ConstFoldContext { let ins = ins.iter().map(|(p, v)| (*p, v.clone())).collect::>(); op.constant_fold(&ins).map_or(Vec::new(), |outs| { outs.into_iter() - .map(|(p, v)| (p, self.value_from_const(n, &v))) + .map(|(p, v)| { + ( + p, + self.value_from_const(n, &v), // Hmmm, should (at least) also key by p + ) + }) .collect() }) } From 9f9a218a049ce5bdbc128fd1b7c1ccf81fb83725 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 25 Oct 2024 19:06:13 +0100 Subject: [PATCH 170/281] (?to revert?) Handle some LoadFunction's by reading subgraphs from Hugr --- hugr-passes/src/const_fold2/context.rs | 35 ++++++++++++++++++-------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs index c9703ce31..6d531bb8d 100644 --- a/hugr-passes/src/const_fold2/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -1,6 +1,7 @@ -use hugr_core::ops::constant::OpaqueValue; -use hugr_core::ops::{ExtensionOp, Value}; -use hugr_core::{IncomingPort, Node, OutgoingPort}; +use hugr_core::hugr::views::{DescendantsGraph, ExtractHugr, HierarchyView}; +use hugr_core::ops::{constant::OpaqueValue, handle::FuncID, ExtensionOp, Value}; +use hugr_core::types::TypeArg; +use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort}; use super::value_handle::ValueHandle; use crate::dataflow::{ConstLoader, PartialValue, TotalContext}; @@ -11,9 +12,9 @@ use crate::dataflow::{ConstLoader, PartialValue, TotalContext}; /// Just stores a Hugr (actually any [HugrView]), /// (there is )no state for operation-interpretation. #[derive(Debug)] -pub struct ConstFoldContext; +pub struct ConstFoldContext(pub H); -impl ConstLoader for ConstFoldContext { +impl ConstLoader for ConstFoldContext { fn value_from_opaque( &self, node: Node, @@ -36,10 +37,23 @@ impl ConstLoader for ConstFoldContext { )) } - // Do not handle (Load)Function/value_from_function yet. + fn value_from_function(&self, node: Node, type_args: &[TypeArg]) -> Option { + if type_args.len() > 0 { + // TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709) + return None; + }; + // Returning the function body as a value, here, would be sufficient for inlining IndirectCall + // but not for transforming to a direct Call. + let func = DescendantsGraph::>::try_new(&self.0, node).ok()?; + Some(ValueHandle::new_const_hugr( + node, + &[], + Box::new(func.extract_hugr()), + )) + } } -impl TotalContext for ConstFoldContext { +impl TotalContext for ConstFoldContext { type InterpretableVal = Value; fn interpret_leaf_op( @@ -66,7 +80,7 @@ impl TotalContext for ConstFoldContext { mod test { use hugr_core::ops::{constant::CustomConst, Value}; use hugr_core::std_extensions::arithmetic::{float_types::ConstF64, int_types::ConstInt}; - use hugr_core::{types::SumType, Node}; + use hugr_core::{types::SumType, Hugr, Node}; use itertools::Itertools; use rstest::rstest; @@ -82,7 +96,8 @@ mod test { let n = Node::from(portgraph::NodeIndex::new(7)); let st = SumType::new([vec![k.get_type()], vec![]]); let subject_val = Value::sum(0, [k.clone().into()], st).unwrap(); - let v1 = ConstFoldContext.value_from_const(n, &subject_val); + let ctx: ConstFoldContext = ConstFoldContext(Hugr::default()); + let v1 = ctx.value_from_const(n, &subject_val); let v1_subfield = { let PartialValue::PartialSum(ps1) = v1 else { @@ -98,7 +113,7 @@ mod test { .unwrap() }; - let v2 = ConstFoldContext.value_from_const(n, &k.into()); + let v2 = ctx.value_from_const(n, &k.into()); if eq { assert_eq!(v1_subfield, v2); } else { From ad0c6f24e31fbe9413c98251244a033025600002 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 28 Oct 2024 14:54:33 +0000 Subject: [PATCH 171/281] Recombine DFContext with Hugr i.e. reinstate Deref constraint --- hugr-passes/src/dataflow.rs | 13 ++-- hugr-passes/src/dataflow/datalog.rs | 99 ++++++++++++++--------------- hugr-passes/src/dataflow/results.rs | 8 +-- hugr-passes/src/dataflow/test.rs | 31 +++++---- 4 files changed, 74 insertions(+), 77 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 23db00383..2ac927552 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -11,14 +11,14 @@ pub use results::{AnalysisResults, TailLoopTermination}; mod partial_value; pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; -use hugr_core::{Hugr, Node}; -use hugr_core::ops::{ExtensionOp, Value}; use hugr_core::ops::constant::OpaqueValue; +use hugr_core::ops::{ExtensionOp, Value}; use hugr_core::types::TypeArg; +use hugr_core::{Hugr, Node}; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). -pub trait DFContext: ConstLoader { +pub trait DFContext: ConstLoader + std::ops::Deref { /// Given lattice values for each input, update lattice values for the (dataflow) outputs. /// For extension ops only, excluding [MakeTuple] and [UnpackTuple]. /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] @@ -51,12 +51,7 @@ pub trait ConstLoader { /// Produces an abstract value from an [OpaqueValue], if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. - fn value_from_opaque( - &self, - _node: Node, - _fields: &[usize], - _val: &OpaqueValue, - ) -> Option { + fn value_from_opaque(&self, _node: Node, _fields: &[usize], _val: &OpaqueValue) -> Option { None } diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index da2aefb59..c684ac9cd 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -40,24 +40,22 @@ impl Machine { /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, /// but should handle other containers.) /// The context passed in allows interpretation of leaf operations. - pub fn run( + pub fn run>( mut self, - context: &impl DFContext, - hugr: H, + context: C, in_values: impl IntoIterator)>, - ) -> AnalysisResults { - let root = hugr.root(); + ) -> AnalysisResults { + let root = context.root(); self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); - run_datalog(self.0, context, hugr) + run_datalog(context, self.0) } } -pub(super) fn run_datalog( +pub(super) fn run_datalog>( + ctx: C, in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, - c: &impl DFContext, - hugr: H, -) -> AnalysisResults { +) -> AnalysisResults { // ascent-(macro-)generated code generates a bunch of warnings, // keep code in here to a minimum. #![allow( @@ -77,47 +75,47 @@ pub(super) fn run_datalog( lattice in_wire_value(Node, IncomingPort, PV); // receives, on , the value lattice node_in_value_row(Node, ValueRow); // 's inputs are - node(n) <-- for n in hugr.nodes(); + node(n) <-- for n in ctx.nodes(); - in_wire(n, p) <-- node(n), for (p,_) in hugr.in_value_types(*n); // Note, gets connected inports only - out_wire(n, p) <-- node(n), for (p,_) in hugr.out_value_types(*n); // (and likewise) + in_wire(n, p) <-- node(n), for (p,_) in ctx.in_value_types(*n); // Note, gets connected inports only + out_wire(n, p) <-- node(n), for (p,_) in ctx.out_value_types(*n); // (and likewise) parent_of_node(parent, child) <-- - node(child), if let Some(parent) = hugr.get_parent(*child); + node(child), if let Some(parent) = ctx.get_parent(*child); - input_child(parent, input) <-- node(parent), if let Some([input, _output]) = hugr.get_io(*parent); - output_child(parent, output) <-- node(parent), if let Some([_input, output]) = hugr.get_io(*parent); + input_child(parent, input) <-- node(parent), if let Some([input, _output]) = ctx.get_io(*parent); + output_child(parent, output) <-- node(parent), if let Some([_input, output]) = ctx.get_io(*parent); // Initialize all wires to bottom out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p); in_wire_value(n, ip, v) <-- in_wire(n, ip), - if let Some((m, op)) = hugr.single_linked_output(*n, *ip), + if let Some((m, op)) = ctx.single_linked_output(*n, *ip), out_wire_value(m, op, v); // Prepopulate in_wire_value from in_wire_value_proto. in_wire_value(n, p, PV::bottom()) <-- in_wire(n, p); in_wire_value(n, p, v) <-- for (n, p, v) in in_wire_value_proto.iter(), node(n), - if let Some(sig) = hugr.signature(*n), + if let Some(sig) = ctx.signature(*n), if sig.input_ports().contains(p); // Assemble in_value_row from in_value's - node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = hugr.signature(*n); - node_in_value_row(n, ValueRow::single_known(hugr.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); + node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = ctx.signature(*n); + node_in_value_row(n, ValueRow::single_known(ctx.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); out_wire_value(n, p, v) <-- node(n), - let op_t = hugr.get_optype(*n), + let op_t = ctx.get_optype(*n), if !op_t.is_container(), if let Some(sig) = op_t.dataflow_signature(), node_in_value_row(n, vs), - if let Some(outs) = propagate_leaf_op(c, &hugr, *n, &vs[..], sig.output_count()), + if let Some(outs) = propagate_leaf_op(&ctx, *n, &vs[..], sig.output_count()), for (p, v) in (0..).map(OutgoingPort::from).zip(outs); // DFG relation dfg_node(Node); // is a `DFG` - dfg_node(n) <-- node(n), if hugr.get_optype(*n).is_dfg(); + dfg_node(n) <-- node(n), if ctx.get_optype(*n).is_dfg(); out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), input_child(dfg, i), in_wire_value(dfg, p, v); @@ -130,13 +128,13 @@ pub(super) fn run_datalog( // inputs of tail loop propagate to Input node of child region out_wire_value(i, OutgoingPort::from(p.index()), v) <-- node(tl), - if hugr.get_optype(*tl).is_tail_loop(), + if ctx.get_optype(*tl).is_tail_loop(), input_child(tl, i), in_wire_value(tl, p, v); // Output node of child region propagate to Input node of child region out_wire_value(in_n, OutgoingPort::from(out_p), v) <-- node(tl), - if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), + if let Some(tailloop) = ctx.get_optype(*tl).as_tail_loop(), input_child(tl, in_n), output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node @@ -145,7 +143,7 @@ pub(super) fn run_datalog( // Output node of child region propagate to outputs of tail loop out_wire_value(tl, OutgoingPort::from(out_p), v) <-- node(tl), - if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), + if let Some(tailloop) = ctx.get_optype(*tl).as_tail_loop(), output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 @@ -156,17 +154,17 @@ pub(super) fn run_datalog( // is a `Conditional` and its 'th child (a `Case`) is : relation case_node(Node, usize, Node); - conditional_node(n)<-- node(n), if hugr.get_optype(*n).is_conditional(); + conditional_node(n)<-- node(n), if ctx.get_optype(*n).is_conditional(); case_node(cond, i, case) <-- conditional_node(cond), - for (i, case) in hugr.children(*cond).enumerate(), - if hugr.get_optype(case).is_case(); + for (i, case) in ctx.children(*cond).enumerate(), + if ctx.get_optype(case).is_case(); // inputs of conditional propagate into case nodes out_wire_value(i_node, OutgoingPort::from(out_p), v) <-- case_node(cond, case_index, case), input_child(case, i_node), node_in_value_row(cond, in_row), - let conditional = hugr.get_optype(*cond).as_conditional().unwrap(), + let conditional = ctx.get_optype(*cond).as_conditional().unwrap(), if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), for (out_p, v) in fields.enumerate(); @@ -185,39 +183,39 @@ pub(super) fn run_datalog( // CFG relation cfg_node(Node); // is a `CFG` - cfg_node(n) <-- node(n), if hugr.get_optype(*n).is_cfg(); + cfg_node(n) <-- node(n), if ctx.get_optype(*n).is_cfg(); // In `CFG` , basic block is reachable given our knowledge of predicates relation bb_reachable(Node, Node); - bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = hugr.children(*cfg).next(); + bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = ctx.children(*cfg).next(); bb_reachable(cfg, bb) <-- cfg_node(cfg), bb_reachable(cfg, pred), output_child(pred, pred_out), in_wire_value(pred_out, IncomingPort::from(0), predicate), - for (tag, bb) in hugr.output_neighbours(*pred).enumerate(), + for (tag, bb) in ctx.output_neighbours(*pred).enumerate(), if predicate.supports_tag(tag); // Inputs of CFG propagate to entry block out_wire_value(i_node, OutgoingPort::from(p.index()), v) <-- cfg_node(cfg), - if let Some(entry) = hugr.children(*cfg).next(), + if let Some(entry) = ctx.children(*cfg).next(), input_child(entry, i_node), in_wire_value(cfg, p, v); // In `CFG` , values fed along a control-flow edge to // come out of Value outports of . relation _cfg_succ_dest(Node, Node, Node); - _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = hugr.children(*cfg).nth(1); + _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = ctx.children(*cfg).nth(1); _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), - for blk in hugr.children(*cfg), - if hugr.get_optype(blk).is_dataflow_block(), + for blk in ctx.children(*cfg), + if ctx.get_optype(blk).is_dataflow_block(), input_child(blk, inp); // Outputs of each reachable block propagated to successor block or CFG itself out_wire_value(dest, OutgoingPort::from(out_p), v) <-- bb_reachable(cfg, pred), - if let Some(df_block) = hugr.get_optype(*pred).as_dataflow_block(), - for (succ_n, succ) in hugr.output_neighbours(*pred).enumerate(), + if let Some(df_block) = ctx.get_optype(*pred).as_dataflow_block(), + for (succ_n, succ) in ctx.output_neighbours(*pred).enumerate(), output_child(pred, out_n), _cfg_succ_dest(cfg, succ, dest), node_in_value_row(out_n, out_in_row), @@ -228,8 +226,8 @@ pub(super) fn run_datalog( relation func_call(Node, Node); // is a `Call` to `FuncDefn` func_call(call, func_defn) <-- node(call), - if hugr.get_optype(*call).is_call(), - if let Some(func_defn) = hugr.static_source(*call); + if ctx.get_optype(*call).is_call(), + if let Some(func_defn) = ctx.static_source(*call); out_wire_value(inp, OutgoingPort::from(p.index()), v) <-- func_call(call, func), @@ -247,7 +245,7 @@ pub(super) fn run_datalog( .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) .collect(); AnalysisResults { - hugr, + hugr: ctx, out_wire_values, in_wire_value: all_results.in_wire_value, case_reachable: all_results.case_reachable, @@ -256,13 +254,12 @@ pub(super) fn run_datalog( } fn propagate_leaf_op( - c: &impl DFContext, - hugr: &impl HugrView, + ctx: &impl DFContext, n: Node, ins: &[PV], num_outs: usize, ) -> Option> { - match hugr.get_optype(n) { + match ctx.get_optype(n) { // Handle basics here. We could instead leave these to DFContext, // but at least we'd want these impls to be easily reusable. op if op.cast::().is_some() => Some(ValueRow::from_iter([PV::new_variant( @@ -284,20 +281,20 @@ fn propagate_leaf_op( OpType::Const(_) => None, // handled by LoadConstant: OpType::LoadConstant(load_op) => { assert!(ins.is_empty()); // static edge, so need to find constant - let const_node = hugr + let const_node = ctx .single_linked_output(n, load_op.constant_port()) .unwrap() .0; - let const_val = hugr.get_optype(const_node).as_const().unwrap().value(); + let const_val = ctx.get_optype(const_node).as_const().unwrap().value(); Some(ValueRow::single_known( 1, 0, - c.value_from_const(n, const_val), + ctx.value_from_const(n, const_val), )) } OpType::LoadFunction(load_op) => { assert!(ins.is_empty()); // static edge - let func_node = hugr + let func_node = ctx .single_linked_output(n, load_op.function_port()) .unwrap() .0; @@ -305,7 +302,7 @@ fn propagate_leaf_op( Some(ValueRow::single_known( 1, 0, - c.value_from_function(func_node, &load_op.type_args) + ctx.value_from_function(func_node, &load_op.type_args) .map_or(PV::Top, PV::Value), )) } @@ -315,7 +312,7 @@ fn propagate_leaf_op( // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, // thus keeping PartialValue hidden, but AbstractValues // are not necessarily convertible to Value. - c.interpret_leaf_op(n, e, ins, &mut outs[..]); + ctx.interpret_leaf_op(n, e, ins, &mut outs[..]); Some(ValueRow::from_iter(outs)) } o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index f457ef68c..3e37e2ce9 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -2,11 +2,11 @@ use std::collections::HashMap; use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; -use super::{AbstractValue, PartialValue}; +use super::{AbstractValue, DFContext, PartialValue}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). -pub struct AnalysisResults { +pub struct AnalysisResults { pub(super) hugr: H, pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, pub(super) case_reachable: Vec<(Node, Node)>, @@ -14,7 +14,7 @@ pub struct AnalysisResults { pub(super) out_wire_values: HashMap>, } -impl AnalysisResults { +impl> AnalysisResults { /// Gets the lattice value computed for the given wire pub fn read_out_wire(&self, w: Wire) -> Option> { self.out_wire_values.get(&w).cloned() @@ -75,7 +75,7 @@ impl AnalysisResults { } } -impl AnalysisResults +impl> AnalysisResults where Value: From, { diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index d6721d8ed..e00827fa8 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -23,10 +23,16 @@ enum Void {} impl AbstractValue for Void {} -struct TestContext; +struct TestContext(H); -impl ConstLoader for TestContext {} -impl DFContext for TestContext {} +impl std::ops::Deref for TestContext { + type Target = Hugr; + fn deref(&self) -> &Hugr { + self.0.base_hugr() + } +} +impl ConstLoader for TestContext {} +impl DFContext for TestContext {} // This allows testing creation of tuple/sum Values (only) impl From for Value { @@ -55,7 +61,7 @@ fn test_make_tuple() { let v3 = builder.make_tuple([v1, v2]).unwrap(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::default().run(TestContext(hugr), []); let x = results.try_read_wire_value(v3).unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); @@ -71,7 +77,7 @@ fn test_unpack_tuple_const() { .outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::default().run(TestContext(hugr), []); let o1_r = results.try_read_wire_value(o1).unwrap(); assert_eq!(o1_r, Value::false_val()); @@ -94,7 +100,7 @@ fn test_tail_loop_never_iterates() { let [tl_o] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::default().run(TestContext(hugr), []); let o_r = results.try_read_wire_value(tl_o).unwrap(); assert_eq!(o_r, r_v); @@ -123,7 +129,7 @@ fn test_tail_loop_always_iterates() { let [tl_o1, tl_o2] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::default().run(TestContext(&hugr), []); let o_r1 = results.read_out_wire(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom()); @@ -161,7 +167,7 @@ fn test_tail_loop_two_iters() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::default().run(TestContext(&hugr), []); let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true_or_false()); @@ -227,7 +233,7 @@ fn test_tail_loop_containing_conditional() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::default().run(TestContext(&hugr), []); let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true()); @@ -279,7 +285,7 @@ fn test_conditional() { 2, [PartialValue::new_variant(0, [])], )); - let results = Machine::default().run(&TestContext, &hugr, [(0.into(), arg_pv)]); + let results = Machine::default().run(TestContext(hugr), [(0.into(), arg_pv)]); let cond_r1 = results.try_read_wire_value(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); @@ -403,8 +409,7 @@ fn test_cfg( ) { let root = xor_and_cfg.root(); let results = Machine::default().run( - &TestContext, - &xor_and_cfg, + TestContext(xor_and_cfg), [(0.into(), inp0), (1.into(), inp1)], ); @@ -440,7 +445,7 @@ fn test_call( .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) .unwrap(); - let results = Machine::default().run(&TestContext, &hugr, [(0.into(), inp0), (1.into(), inp1)]); + let results = Machine::default().run(TestContext(&hugr), [(0.into(), inp0), (1.into(), inp1)]); let [res0, res1] = [0, 1].map(|i| results.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); // The two calls alias so both results will be the same: From 16a18f4bfe048350116f0e5e5a003567faa4d072 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 28 Oct 2024 14:59:44 +0000 Subject: [PATCH 172/281] Replace Deref with HugrView, trivially obtainable by implementing AsRef --- hugr-passes/src/dataflow.rs | 4 ++-- hugr-passes/src/dataflow/results.rs | 8 ++++---- hugr-passes/src/dataflow/test.rs | 5 ++--- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 2ac927552..c960fb7de 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -14,11 +14,11 @@ pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, Value}; use hugr_core::types::TypeArg; -use hugr_core::{Hugr, Node}; +use hugr_core::{Hugr, HugrView, Node}; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). -pub trait DFContext: ConstLoader + std::ops::Deref { +pub trait DFContext: ConstLoader + HugrView { /// Given lattice values for each input, update lattice values for the (dataflow) outputs. /// For extension ops only, excluding [MakeTuple] and [UnpackTuple]. /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index 3e37e2ce9..f457ef68c 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -2,11 +2,11 @@ use std::collections::HashMap; use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; -use super::{AbstractValue, DFContext, PartialValue}; +use super::{AbstractValue, PartialValue}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). -pub struct AnalysisResults { +pub struct AnalysisResults { pub(super) hugr: H, pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, pub(super) case_reachable: Vec<(Node, Node)>, @@ -14,7 +14,7 @@ pub struct AnalysisResults { pub(super) out_wire_values: HashMap>, } -impl> AnalysisResults { +impl AnalysisResults { /// Gets the lattice value computed for the given wire pub fn read_out_wire(&self, w: Wire) -> Option> { self.out_wire_values.get(&w).cloned() @@ -75,7 +75,7 @@ impl> AnalysisResults { } } -impl> AnalysisResults +impl AnalysisResults where Value: From, { diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index e00827fa8..e73cc8e4f 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -25,9 +25,8 @@ impl AbstractValue for Void {} struct TestContext(H); -impl std::ops::Deref for TestContext { - type Target = Hugr; - fn deref(&self) -> &Hugr { +impl AsRef for TestContext { + fn as_ref(&self) -> &Hugr { self.0.base_hugr() } } From 8058dab97c0508da6b54b20c93d963d2f0d1b105 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 28 Oct 2024 15:25:16 +0000 Subject: [PATCH 173/281] Implement pass via MutRefCell (TODO: Combine with ConstFoldContext) --- hugr-core/src/hugr/views.rs | 22 ++-- hugr-passes/src/const_fold2.rs | 135 ++++++++++++++++++++++ hugr-passes/src/const_fold2/context.rs | 4 +- hugr-passes/src/dataflow/results.rs | 2 +- hugr-passes/src/dataflow/total_context.rs | 2 +- 5 files changed, 153 insertions(+), 12 deletions(-) diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index 6a52f33f0..fc90df10d 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -528,14 +528,6 @@ impl RootTagged for Hugr { type RootHandle = Node; } -impl RootTagged for &Hugr { - type RootHandle = Node; -} - -impl RootTagged for &mut Hugr { - type RootHandle = Node; -} - // Explicit implementation to avoid cloning the Hugr. impl ExtractHugr for Hugr { fn extract_hugr(self) -> Hugr { @@ -555,6 +547,20 @@ impl ExtractHugr for &mut Hugr { } } +impl<'a, H: RootTagged> RootTagged for &'a H +where + &'a H: HugrView, +{ + type RootHandle = H::RootHandle; +} + +impl<'a, H: RootTagged> RootTagged for &'a mut H +where + &'a mut H: HugrView, +{ + type RootHandle = H::RootHandle; +} + impl> HugrView for T { /// An Iterator over the nodes in a Hugr(View) type Nodes<'a> = MapInto, Node> where Self: 'a; diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index 3801ff134..321f58fd2 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -4,4 +4,139 @@ // These are pub because this "example" is used for testing the framework. mod context; pub mod value_handle; +use std::collections::{HashSet, VecDeque}; + pub use context::ConstFoldContext; +use hugr_core::{ + extension::ExtensionRegistry, + hugr::hugrmut::HugrMut, + ops::{Const, LoadConstant}, + types::EdgeKind, + HugrView, IncomingPort, Node, OutgoingPort, Wire, +}; +use value_handle::ValueHandle; + +use crate::{ + dataflow::{AnalysisResults, Machine, TailLoopTermination}, + validation::{ValidatePassError, ValidationLevel}, +}; + +pub struct ConstFoldPass { + validation: ValidationLevel, + /// If true, allow to skip evaluating loops (whose results are not needed) even if + /// we are not sure they will terminate. (If they definitely terminate then fair game.) + pub allow_skip_loops: bool, +} + +struct MutRefCell<'a, H>(&'a mut H); + +impl<'a, T: HugrView> AsRef for MutRefCell<'a, T> { + fn as_ref(&self) -> &hugr_core::Hugr { + self.0.base_hugr() + } +} + +impl ConstFoldPass { + /// Run the Constant Folding pass. + fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { + let ctx = ConstFoldContext(MutRefCell(hugr)); + let results = Machine::default().run(ctx, []); + let mut keep_nodes = HashSet::new(); + self.find_needed_nodes(&results, results.hugr.root(), &mut keep_nodes); + + let remove_nodes = results + .hugr + .nodes() + .filter(|n| !keep_nodes.contains(n)) + .collect::>(); + for n in keep_nodes { + // Every input either (a) is in keep_nodes, or (b) has a known value. Break all wires (b). + for inport in results.hugr.node_inputs(n) { + if matches!( + results.hugr.get_optype(n).port_kind(inport).unwrap(), + EdgeKind::Value(_) + ) { + let (src, outp) = results.hugr.single_linked_output(n, inport).unwrap(); + if let Ok(v) = results.try_read_wire_value(Wire::new(src, outp)) { + let parent = results.hugr.get_parent(n).unwrap(); + let datatype = v.get_type(); + // We could try hash-consing identical Consts, but not ATM + let hugr_mut = &mut *results.hugr.0 .0; + let cst = hugr_mut.add_node_with_parent(parent, Const::new(v)); + let lcst = hugr_mut.add_node_with_parent(parent, LoadConstant { datatype }); + hugr_mut.connect(cst, OutgoingPort::from(0), lcst, IncomingPort::from(0)); + } + } + } + } + for n in remove_nodes { + hugr.remove_node(n); + } + Ok(()) + } + + pub fn run( + &self, + hugr: &mut H, + reg: &ExtensionRegistry, + ) -> Result<(), ValidatePassError> { + self.validation + .run_validated_pass(hugr, reg, |hugr: &mut H, _| self.run_no_validate(hugr)) + } + + fn find_needed_nodes( + &self, + results: &AnalysisResults, + container: Node, + needed: &mut HashSet, + ) { + let h = &results.hugr; + if h.get_optype(container).is_cfg() { + for bb in h.children(container) { + if results.bb_reachable(bb).unwrap() + && needed.insert(bb) + && h.get_optype(bb).is_dataflow_block() + { + self.find_needed_nodes(results, bb, needed); + } + } + } else { + // Dataflow. Find minimal nodes necessary to compute output, including StateOrder edges. + let [_inp, outp] = h.get_io(container).unwrap(); + let mut q = VecDeque::new(); + q.push_back(outp); + // Add on anything that might not terminate. We might also allow a custom predicate for extension ops? + for n in h.children(container) { + if h.get_optype(n).is_cfg() + || (!self.allow_skip_loops + && h.get_optype(n).is_tail_loop() + && results.tail_loop_terminates(n).unwrap() + != TailLoopTermination::NeverContinues) + { + q.push_back(n); + } + } + while let Some(n) = q.pop_front() { + if !needed.insert(n) { + continue; + } + for (src, op) in h.all_linked_outputs(n) { + let needs_predecessor = match h.get_optype(src).port_kind(op).unwrap() { + EdgeKind::Value(_) => { + results.try_read_wire_value(Wire::new(src, op)).is_err() + } + EdgeKind::StateOrder | EdgeKind::Const(_) | EdgeKind::Function(_) => true, + EdgeKind::ControlFlow => panic!(), + _ => true, // needed for non-exhaustive; not knowing what it is, assume the worst + }; + if needs_predecessor { + q.push_back(src); + } + } + if h.get_optype(n).is_container() { + self.find_needed_nodes(results, container, needed); + } + } + } + } +} diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs index 22850c55d..5f038c796 100644 --- a/hugr-passes/src/const_fold2/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -12,9 +12,9 @@ use crate::dataflow::{ConstLoader, PartialValue, TotalContext}; /// Just stores a Hugr (actually any [HugrView]), /// (there is )no state for operation-interpretation. #[derive(Debug)] -pub struct ConstFoldContext(pub H); +pub struct ConstFoldContext(pub(super) H); -impl AsRef for ConstFoldContext { +impl AsRef for ConstFoldContext { fn as_ref(&self) -> &Hugr { self.0.base_hugr() } diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index f457ef68c..a8800f755 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -7,7 +7,7 @@ use super::{AbstractValue, PartialValue}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). pub struct AnalysisResults { - pub(super) hugr: H, + pub hugr: H, pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, pub(super) case_reachable: Vec<(Node, Node)>, pub(super) bb_reachable: Vec<(Node, Node)>, diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index 33a96ca0d..f325b046a 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -7,7 +7,7 @@ use super::{ConstLoader, DFContext}; /// A simpler interface like [DFContext] but where the context only cares about /// values that are completely known (in the lattice `V`) rather than partially /// (e.g. no [PartialSum]s of more than one variant, no top/bottom) -pub trait TotalContext { +pub trait TotalContext: ConstLoader { /// Representation of a (single, non-partial) value usable for interpretation type InterpretableVal: From + TryFrom>; From a766ed981d22ab8a60c931450ab321f639989ce8 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 28 Oct 2024 15:46:19 +0000 Subject: [PATCH 174/281] Combine ConstFoldContext with MutRefCell, context.rs -> test.rs --- hugr-passes/src/const_fold2.rs | 87 ++++++++++++++--- hugr-passes/src/const_fold2/context.rs | 129 ------------------------- hugr-passes/src/const_fold2/test.rs | 45 +++++++++ 3 files changed, 121 insertions(+), 140 deletions(-) delete mode 100644 hugr-passes/src/const_fold2/context.rs create mode 100644 hugr-passes/src/const_fold2/test.rs diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index 321f58fd2..1e57a8d2b 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -2,22 +2,26 @@ //! to perform constant-folding. // These are pub because this "example" is used for testing the framework. -mod context; +mod test; pub mod value_handle; use std::collections::{HashSet, VecDeque}; -pub use context::ConstFoldContext; use hugr_core::{ extension::ExtensionRegistry, - hugr::hugrmut::HugrMut, - ops::{Const, LoadConstant}, - types::EdgeKind, + hugr::{ + hugrmut::HugrMut, + views::{DescendantsGraph, ExtractHugr, HierarchyView}, + }, + ops::{constant::OpaqueValue, handle::FuncID, Const, ExtensionOp, LoadConstant, Value}, + types::{EdgeKind, TypeArg}, HugrView, IncomingPort, Node, OutgoingPort, Wire, }; use value_handle::ValueHandle; use crate::{ - dataflow::{AnalysisResults, Machine, TailLoopTermination}, + dataflow::{ + AnalysisResults, ConstLoader, Machine, PartialValue, TailLoopTermination, TotalContext, + }, validation::{ValidatePassError, ValidationLevel}, }; @@ -28,19 +32,80 @@ pub struct ConstFoldPass { pub allow_skip_loops: bool, } -struct MutRefCell<'a, H>(&'a mut H); +struct ConstFoldContext<'a, H>(&'a mut H); -impl<'a, T: HugrView> AsRef for MutRefCell<'a, T> { +impl<'a, T: HugrView> AsRef for ConstFoldContext<'a, T> { fn as_ref(&self) -> &hugr_core::Hugr { self.0.base_hugr() } } +impl<'a, H: HugrView> ConstLoader for ConstFoldContext<'a, H> { + fn value_from_opaque( + &self, + node: Node, + fields: &[usize], + val: &OpaqueValue, + ) -> Option { + Some(ValueHandle::new_opaque(node, fields, val.clone())) + } + + fn value_from_const_hugr( + &self, + node: Node, + fields: &[usize], + h: &hugr_core::Hugr, + ) -> Option { + Some(ValueHandle::new_const_hugr( + node, + fields, + Box::new(h.clone()), + )) + } + + fn value_from_function(&self, node: Node, type_args: &[TypeArg]) -> Option { + if type_args.len() > 0 { + // TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709) + return None; + }; + // Returning the function body as a value, here, would be sufficient for inlining IndirectCall + // but not for transforming to a direct Call. + let func = DescendantsGraph::>::try_new(self, node).ok()?; + Some(ValueHandle::new_const_hugr( + node, + &[], + Box::new(func.extract_hugr()), + )) + } +} + +impl<'a, H: HugrView> TotalContext for ConstFoldContext<'a, H> { + type InterpretableVal = Value; + + fn interpret_leaf_op( + &self, + n: Node, + op: &ExtensionOp, + ins: &[(IncomingPort, Value)], + ) -> Vec<(OutgoingPort, PartialValue)> { + let ins = ins.iter().map(|(p, v)| (*p, v.clone())).collect::>(); + op.constant_fold(&ins).map_or(Vec::new(), |outs| { + outs.into_iter() + .map(|(p, v)| { + ( + p, + self.value_from_const(n, &v), // Hmmm, should (at least) also key by p + ) + }) + .collect() + }) + } +} + impl ConstFoldPass { /// Run the Constant Folding pass. fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { - let ctx = ConstFoldContext(MutRefCell(hugr)); - let results = Machine::default().run(ctx, []); + let results = Machine::default().run(ConstFoldContext(hugr), []); let mut keep_nodes = HashSet::new(); self.find_needed_nodes(&results, results.hugr.root(), &mut keep_nodes); @@ -61,7 +126,7 @@ impl ConstFoldPass { let parent = results.hugr.get_parent(n).unwrap(); let datatype = v.get_type(); // We could try hash-consing identical Consts, but not ATM - let hugr_mut = &mut *results.hugr.0 .0; + let hugr_mut = &mut *results.hugr.0; let cst = hugr_mut.add_node_with_parent(parent, Const::new(v)); let lcst = hugr_mut.add_node_with_parent(parent, LoadConstant { datatype }); hugr_mut.connect(cst, OutgoingPort::from(0), lcst, IncomingPort::from(0)); diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs deleted file mode 100644 index 5f038c796..000000000 --- a/hugr-passes/src/const_fold2/context.rs +++ /dev/null @@ -1,129 +0,0 @@ -use hugr_core::hugr::views::{DescendantsGraph, ExtractHugr, HierarchyView}; -use hugr_core::ops::{constant::OpaqueValue, handle::FuncID, ExtensionOp, Value}; -use hugr_core::types::TypeArg; -use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; - -use super::value_handle::ValueHandle; -use crate::dataflow::{ConstLoader, PartialValue, TotalContext}; - -/// A [context](crate::dataflow::DFContext) that uses [ValueHandle]s -/// and performs [ExtensionOp::constant_fold] (using [Value]s for extension-op inputs). -/// -/// Just stores a Hugr (actually any [HugrView]), -/// (there is )no state for operation-interpretation. -#[derive(Debug)] -pub struct ConstFoldContext(pub(super) H); - -impl AsRef for ConstFoldContext { - fn as_ref(&self) -> &Hugr { - self.0.base_hugr() - } -} - -impl ConstLoader for ConstFoldContext { - fn value_from_opaque( - &self, - node: Node, - fields: &[usize], - val: &OpaqueValue, - ) -> Option { - Some(ValueHandle::new_opaque(node, fields, val.clone())) - } - - fn value_from_const_hugr( - &self, - node: Node, - fields: &[usize], - h: &hugr_core::Hugr, - ) -> Option { - Some(ValueHandle::new_const_hugr( - node, - fields, - Box::new(h.clone()), - )) - } - - fn value_from_function(&self, node: Node, type_args: &[TypeArg]) -> Option { - if type_args.len() > 0 { - // TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709) - return None; - }; - // Returning the function body as a value, here, would be sufficient for inlining IndirectCall - // but not for transforming to a direct Call. - let func = DescendantsGraph::>::try_new(&self.0, node).ok()?; - Some(ValueHandle::new_const_hugr( - node, - &[], - Box::new(func.extract_hugr()), - )) - } -} - -impl TotalContext for ConstFoldContext { - type InterpretableVal = Value; - - fn interpret_leaf_op( - &self, - n: Node, - op: &ExtensionOp, - ins: &[(IncomingPort, Value)], - ) -> Vec<(OutgoingPort, PartialValue)> { - let ins = ins.iter().map(|(p, v)| (*p, v.clone())).collect::>(); - op.constant_fold(&ins).map_or(Vec::new(), |outs| { - outs.into_iter() - .map(|(p, v)| { - ( - p, - self.value_from_const(n, &v), // Hmmm, should (at least) also key by p - ) - }) - .collect() - }) - } -} - -#[cfg(test)] -mod test { - use hugr_core::ops::{constant::CustomConst, Value}; - use hugr_core::std_extensions::arithmetic::{float_types::ConstF64, int_types::ConstInt}; - use hugr_core::{types::SumType, Hugr, Node}; - use itertools::Itertools; - use rstest::rstest; - - use crate::{ - const_fold2::ConstFoldContext, - dataflow::{ConstLoader, PartialValue}, - }; - - #[rstest] - #[case(ConstInt::new_u(4, 2).unwrap(), true)] - #[case(ConstF64::new(std::f64::consts::PI), false)] - fn value_handling(#[case] k: impl CustomConst + Clone, #[case] eq: bool) { - let n = Node::from(portgraph::NodeIndex::new(7)); - let st = SumType::new([vec![k.get_type()], vec![]]); - let subject_val = Value::sum(0, [k.clone().into()], st).unwrap(); - let ctx: ConstFoldContext = ConstFoldContext(Hugr::default()); - let v1 = ctx.value_from_const(n, &subject_val); - - let v1_subfield = { - let PartialValue::PartialSum(ps1) = v1 else { - panic!() - }; - ps1.0 - .into_iter() - .exactly_one() - .unwrap() - .1 - .into_iter() - .exactly_one() - .unwrap() - }; - - let v2 = ctx.value_from_const(n, &k.into()); - if eq { - assert_eq!(v1_subfield, v2); - } else { - assert_ne!(v1_subfield, v2); - } - } -} diff --git a/hugr-passes/src/const_fold2/test.rs b/hugr-passes/src/const_fold2/test.rs new file mode 100644 index 000000000..f329bf075 --- /dev/null +++ b/hugr-passes/src/const_fold2/test.rs @@ -0,0 +1,45 @@ +#![cfg(test)] + +use hugr_core::ops::{constant::CustomConst, Value}; +use hugr_core::std_extensions::arithmetic::{float_types::ConstF64, int_types::ConstInt}; +use hugr_core::{types::SumType, Hugr, Node}; +use itertools::Itertools; +use rstest::rstest; + +use crate::{ + const_fold2::ConstFoldContext, + dataflow::{ConstLoader, PartialValue}, +}; + +#[rstest] +#[case(ConstInt::new_u(4, 2).unwrap(), true)] +#[case(ConstF64::new(std::f64::consts::PI), false)] +fn value_handling(#[case] k: impl CustomConst + Clone, #[case] eq: bool) { + let n = Node::from(portgraph::NodeIndex::new(7)); + let st = SumType::new([vec![k.get_type()], vec![]]); + let subject_val = Value::sum(0, [k.clone().into()], st).unwrap(); + let mut temp = Hugr::default(); + let ctx: ConstFoldContext = ConstFoldContext(&mut temp); + let v1 = ctx.value_from_const(n, &subject_val); + + let v1_subfield = { + let PartialValue::PartialSum(ps1) = v1 else { + panic!() + }; + ps1.0 + .into_iter() + .exactly_one() + .unwrap() + .1 + .into_iter() + .exactly_one() + .unwrap() + }; + + let v2 = ctx.value_from_const(n, &k.into()); + if eq { + assert_eq!(v1_subfield, v2); + } else { + assert_ne!(v1_subfield, v2); + } +} From a0f2b2ca1a4418b8b66f2208bdbe09d7adcc9d95 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 28 Oct 2024 17:33:18 +0000 Subject: [PATCH 175/281] ValueRow Debug; ops default to PartialValue::Top less aggressively --- hugr-passes/src/dataflow/datalog.rs | 12 ++++++++++-- hugr-passes/src/dataflow/value_row.rs | 2 +- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index c684ac9cd..2cc535d89 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -307,8 +307,16 @@ fn propagate_leaf_op( )) } OpType::ExtensionOp(e) => { - // Interpret op. Default is we know nothing about the outputs (they still happen!) - let mut outs = vec![PartialValue::Top; num_outs]; + // Interpret op. + let init = if ins.iter().contains(&PartialValue::Bottom) { + // So far we think one or more inputs can't happen. + // So, don't pollute outputs with Top, and wait for better knowledge of inputs. + PartialValue::Bottom + } else { + // If we can't figure out anything about the outputs, assume nothing (they still happen!) + PartialValue::Top + }; + let mut outs = vec![init; num_outs]; // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, // thus keeping PartialValue hidden, but AbstractValues // are not necessarily convertible to Value. diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs index ebdbf1b75..0d8bc15a6 100644 --- a/hugr-passes/src/dataflow/value_row.rs +++ b/hugr-passes/src/dataflow/value_row.rs @@ -10,7 +10,7 @@ use itertools::zip_eq; use super::{AbstractValue, PartialValue}; -#[derive(PartialEq, Clone, Eq, Hash)] +#[derive(PartialEq, Clone, Debug, Eq, Hash)] pub(super) struct ValueRow(Vec>); impl ValueRow { From 53447fe3b5f2384fed724321bb97efcda0b0c738 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 28 Oct 2024 15:49:24 +0000 Subject: [PATCH 176/281] Reorder, make look more like original const_fold --- hugr-passes/src/const_fold2.rs | 154 +++++++++++++++------------- hugr-passes/src/const_fold2/test.rs | 8 +- 2 files changed, 86 insertions(+), 76 deletions(-) diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index 1e57a8d2b..49d7bf457 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -2,7 +2,6 @@ //! to perform constant-folding. // These are pub because this "example" is used for testing the framework. -mod test; pub mod value_handle; use std::collections::{HashSet, VecDeque}; @@ -25,6 +24,8 @@ use crate::{ validation::{ValidatePassError, ValidationLevel}, }; +#[derive(Debug, Clone, Default)] +/// A configuration for the Constant Folding pass. pub struct ConstFoldPass { validation: ValidationLevel, /// If true, allow to skip evaluating loops (whose results are not needed) even if @@ -32,77 +33,12 @@ pub struct ConstFoldPass { pub allow_skip_loops: bool, } -struct ConstFoldContext<'a, H>(&'a mut H); - -impl<'a, T: HugrView> AsRef for ConstFoldContext<'a, T> { - fn as_ref(&self) -> &hugr_core::Hugr { - self.0.base_hugr() - } -} - -impl<'a, H: HugrView> ConstLoader for ConstFoldContext<'a, H> { - fn value_from_opaque( - &self, - node: Node, - fields: &[usize], - val: &OpaqueValue, - ) -> Option { - Some(ValueHandle::new_opaque(node, fields, val.clone())) - } - - fn value_from_const_hugr( - &self, - node: Node, - fields: &[usize], - h: &hugr_core::Hugr, - ) -> Option { - Some(ValueHandle::new_const_hugr( - node, - fields, - Box::new(h.clone()), - )) - } - - fn value_from_function(&self, node: Node, type_args: &[TypeArg]) -> Option { - if type_args.len() > 0 { - // TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709) - return None; - }; - // Returning the function body as a value, here, would be sufficient for inlining IndirectCall - // but not for transforming to a direct Call. - let func = DescendantsGraph::>::try_new(self, node).ok()?; - Some(ValueHandle::new_const_hugr( - node, - &[], - Box::new(func.extract_hugr()), - )) - } -} - -impl<'a, H: HugrView> TotalContext for ConstFoldContext<'a, H> { - type InterpretableVal = Value; - - fn interpret_leaf_op( - &self, - n: Node, - op: &ExtensionOp, - ins: &[(IncomingPort, Value)], - ) -> Vec<(OutgoingPort, PartialValue)> { - let ins = ins.iter().map(|(p, v)| (*p, v.clone())).collect::>(); - op.constant_fold(&ins).map_or(Vec::new(), |outs| { - outs.into_iter() - .map(|(p, v)| { - ( - p, - self.value_from_const(n, &v), // Hmmm, should (at least) also key by p - ) - }) - .collect() - }) +impl ConstFoldPass { + pub fn validation_level(mut self, level: ValidationLevel) -> Self { + self.validation = level; + self } -} -impl ConstFoldPass { /// Run the Constant Folding pass. fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { let results = Machine::default().run(ConstFoldContext(hugr), []); @@ -205,3 +141,81 @@ impl ConstFoldPass { } } } + +/// Exhaustively apply constant folding to a HUGR. +pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { + ConstFoldPass::default().run(h, reg).unwrap() +} + +struct ConstFoldContext<'a, H>(&'a mut H); + +impl<'a, T: HugrView> AsRef for ConstFoldContext<'a, T> { + fn as_ref(&self) -> &hugr_core::Hugr { + self.0.base_hugr() + } +} + +impl<'a, H: HugrView> ConstLoader for ConstFoldContext<'a, H> { + fn value_from_opaque( + &self, + node: Node, + fields: &[usize], + val: &OpaqueValue, + ) -> Option { + Some(ValueHandle::new_opaque(node, fields, val.clone())) + } + + fn value_from_const_hugr( + &self, + node: Node, + fields: &[usize], + h: &hugr_core::Hugr, + ) -> Option { + Some(ValueHandle::new_const_hugr( + node, + fields, + Box::new(h.clone()), + )) + } + + fn value_from_function(&self, node: Node, type_args: &[TypeArg]) -> Option { + if type_args.len() > 0 { + // TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709) + return None; + }; + // Returning the function body as a value, here, would be sufficient for inlining IndirectCall + // but not for transforming to a direct Call. + let func = DescendantsGraph::>::try_new(self, node).ok()?; + Some(ValueHandle::new_const_hugr( + node, + &[], + Box::new(func.extract_hugr()), + )) + } +} + +impl<'a, H: HugrView> TotalContext for ConstFoldContext<'a, H> { + type InterpretableVal = Value; + + fn interpret_leaf_op( + &self, + n: Node, + op: &ExtensionOp, + ins: &[(IncomingPort, Value)], + ) -> Vec<(OutgoingPort, PartialValue)> { + let ins = ins.iter().map(|(p, v)| (*p, v.clone())).collect::>(); + op.constant_fold(&ins).map_or(Vec::new(), |outs| { + outs.into_iter() + .map(|(p, v)| { + ( + p, + self.value_from_const(n, &v), // Hmmm, should (at least) also key by p + ) + }) + .collect() + }) + } +} + +#[cfg(test)] +mod test; diff --git a/hugr-passes/src/const_fold2/test.rs b/hugr-passes/src/const_fold2/test.rs index f329bf075..a6e368ec2 100644 --- a/hugr-passes/src/const_fold2/test.rs +++ b/hugr-passes/src/const_fold2/test.rs @@ -1,15 +1,11 @@ -#![cfg(test)] - use hugr_core::ops::{constant::CustomConst, Value}; use hugr_core::std_extensions::arithmetic::{float_types::ConstF64, int_types::ConstInt}; use hugr_core::{types::SumType, Hugr, Node}; use itertools::Itertools; use rstest::rstest; -use crate::{ - const_fold2::ConstFoldContext, - dataflow::{ConstLoader, PartialValue}, -}; +use super::ConstFoldContext; +use crate::dataflow::{ConstLoader, PartialValue}; #[rstest] #[case(ConstInt::new_u(4, 2).unwrap(), true)] From e45a620b44f79598bf81821ee2338904f9dd6f60 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 28 Oct 2024 16:18:26 +0000 Subject: [PATCH 177/281] copy all old const_fold tests over --- hugr-passes/src/const_fold2/test.rs | 1521 ++++++++++++++++++++++++++- 1 file changed, 1516 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/const_fold2/test.rs b/hugr-passes/src/const_fold2/test.rs index a6e368ec2..110ae0633 100644 --- a/hugr-passes/src/const_fold2/test.rs +++ b/hugr-passes/src/const_fold2/test.rs @@ -1,11 +1,28 @@ -use hugr_core::ops::{constant::CustomConst, Value}; -use hugr_core::std_extensions::arithmetic::{float_types::ConstF64, int_types::ConstInt}; -use hugr_core::{types::SumType, Hugr, Node}; +use hugr_core::builder::{inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr}; +use hugr_core::extension::prelude::{ + const_ok, sum_with_error, ConstError, ConstString, UnpackTuple, BOOL_T, ERROR_TYPE, STRING_TYPE, +}; +use hugr_core::extension::{ExtensionRegistry, PRELUDE}; +use hugr_core::hugr::hugrmut::HugrMut; +use hugr_core::ops::{constant::CustomConst, OpType, Value}; +use hugr_core::std_extensions::arithmetic::{ + self, + conversions::ConvertOpDef, + float_ops::FloatOps, + float_types::{ConstF64, FLOAT64_TYPE}, + int_ops::IntOpDef, + int_types::{ConstInt, INT_TYPES}, +}; +use hugr_core::std_extensions::logic::{self, LogicOp}; +use hugr_core::types::{Signature, SumType, Type, TypeRow, TypeRowRV}; +use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node}; + use itertools::Itertools; +use lazy_static::lazy_static; use rstest::rstest; -use super::ConstFoldContext; -use crate::dataflow::{ConstLoader, PartialValue}; +use super::{constant_fold_pass, ConstFoldContext}; +use crate::dataflow::{ConstLoader, DFContext, PartialValue}; #[rstest] #[case(ConstInt::new_u(4, 2).unwrap(), true)] @@ -39,3 +56,1497 @@ fn value_handling(#[case] k: impl CustomConst + Clone, #[case] eq: bool) { assert_ne!(v1_subfield, v2); } } + +/// Check that a hugr just loads and returns a single expected constant. +pub fn assert_fully_folded(h: &Hugr, expected_value: &Value) { + assert_fully_folded_with(h, |v| v == expected_value) +} + +/// Check that a hugr just loads and returns a single constant, and validate +/// that constant using `check_value`. +/// +/// [CustomConst::equals_const] is not required to be implemented. Use this +/// function for Values containing such a `CustomConst`. +fn assert_fully_folded_with(h: &Hugr, check_value: impl Fn(&Value) -> bool) { + let mut node_count = 0; + + for node in h.children(h.root()) { + let op = h.get_optype(node); + match op { + OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1, + OpType::Const(c) if check_value(c.value()) => node_count += 1, + _ => panic!("unexpected op: {:?}\n{}", op, h.mermaid_string()), + } + } + + assert_eq!(node_count, 4); +} + +/// int to constant +fn i2c(b: u64) -> Value { + Value::extension(ConstInt::new_u(5, b).unwrap()) +} + +/// float to constant +fn f2c(f: f64) -> Value { + ConstF64::new(f).into() +} + +#[rstest] +#[case(0.0, 0.0, 0.0)] +#[case(0.0, 1.0, 1.0)] +#[case(23.5, 435.5, 459.0)] +// c = a + b +fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { + let [n, n_a, n_b] = [0, 1, 2].map(portgraph::NodeIndex::new).map(Node::from); + let mut temp = Hugr::default(); + let ctx = ConstFoldContext(&mut temp); + let v_a = ctx.value_from_const(n_a, &f2c(a)); + let v_b = ctx.value_from_const(n_b, &f2c(b)); + assert_eq!(v_a.clone().try_into_value(&FLOAT64_TYPE), Ok(f2c(a))); + assert_eq!(v_b.clone().try_into_value(&FLOAT64_TYPE), Ok(f2c(b))); + + let mut outs = [PartialValue::Bottom]; + let OpType::ExtensionOp(add_op) = OpType::from(FloatOps::fadd) else { + panic!() + }; + ctx.interpret_leaf_op(n, &add_op, &[v_a, v_b], &mut outs); + + assert_eq!(outs[0].clone().try_into_value(&FLOAT64_TYPE), Ok(f2c(c))); +} + +fn noargfn(outputs: impl Into) -> Signature { + inout_sig(type_row![], outputs) +} + +#[test] +fn test_big() { + /* + Test approximately calculates + let x = (5.6, 3.2); + int(x.0 - x.1) == 2 + */ + let sum_type = sum_with_error(INT_TYPES[5].to_owned()); + let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); + + let tup = build.add_load_const(Value::tuple([f2c(5.6), f2c(3.2)])); + + let unpack = build + .add_dataflow_op( + UnpackTuple::new(type_row![FLOAT64_TYPE, FLOAT64_TYPE]), + [tup], + ) + .unwrap(); + + let sub = build + .add_dataflow_op(FloatOps::fsub, unpack.outputs()) + .unwrap(); + let to_int = build + .add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(5), sub.outputs()) + .unwrap(); + + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + arithmetic::float_ops::EXTENSION.to_owned(), + arithmetic::conversions::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build + .finish_hugr_with_outputs(to_int.outputs(), ®) + .unwrap(); + assert_eq!(h.node_count(), 8); + + constant_fold_pass(&mut h, ®); + + let expected = const_ok(i2c(2).clone(), ERROR_TYPE); + assert_fully_folded(&h, &expected); +} + +#[test] +#[ignore = "Waiting for `unwrap` operation"] +// TODO: https://github.com/CQCL/hugr/issues/1486 +fn test_list_ops() -> Result<(), Box> { + use hugr_core::std_extensions::collections::{self, ListOp, ListValue}; + + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + logic::EXTENSION.to_owned(), + collections::EXTENSION.to_owned(), + ]) + .unwrap(); + let base_list: Value = ListValue::new(BOOL_T, [Value::false_val()]).into(); + let mut build = DFGBuilder::new(Signature::new( + type_row![], + vec![base_list.get_type().clone()], + )) + .unwrap(); + + let list = build.add_load_const(base_list.clone()); + + let [list, maybe_elem] = build + .add_dataflow_op( + ListOp::pop.with_type(BOOL_T).to_extension_op(®).unwrap(), + [list], + )? + .outputs_arr(); + + // FIXME: Unwrap the Option + let elem = maybe_elem; + + let [list] = build + .add_dataflow_op( + ListOp::push + .with_type(BOOL_T) + .to_extension_op(®) + .unwrap(), + [list, elem], + )? + .outputs_arr(); + + let mut h = build.finish_hugr_with_outputs([list], ®)?; + + constant_fold_pass(&mut h, ®); + + assert_fully_folded(&h, &base_list); + Ok(()) +} + +#[test] +fn test_fold_and() { + // pseudocode: + // x0, x1 := bool(true), bool(true) + // x2 := and(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(noargfn(BOOL_T)).unwrap(); + let x0 = build.add_load_const(Value::true_val()); + let x1 = build.add_load_const(Value::true_val()); + let x2 = build.add_dataflow_op(LogicOp::And, [x0, x1]).unwrap(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_or() { + // pseudocode: + // x0, x1 := bool(true), bool(false) + // x2 := or(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(noargfn(BOOL_T)).unwrap(); + let x0 = build.add_load_const(Value::true_val()); + let x1 = build.add_load_const(Value::false_val()); + let x2 = build.add_dataflow_op(LogicOp::Or, [x0, x1]).unwrap(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_not() { + // pseudocode: + // x0 := bool(true) + // x1 := not(x0) + // output x1 == false; + let mut build = DFGBuilder::new(noargfn(BOOL_T)).unwrap(); + let x0 = build.add_load_const(Value::true_val()); + let x1 = build.add_dataflow_op(LogicOp::Not, [x0]).unwrap(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::false_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn orphan_output() { + // pseudocode: + // x0 := bool(true) + // x1 := not(x0) + // x2 := or(x0,x1) + // output x2 == true; + // + // We arrange things so that the `or` folds away first, leaving the not + // with no outputs. + use hugr_core::ops::handle::NodeHandle; + + let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let true_wire = build.add_load_value(Value::true_val()); + // this Not will be manually replaced + let orig_not = build.add_dataflow_op(LogicOp::Not, [true_wire]).unwrap(); + let r = build + .add_dataflow_op(LogicOp::Or, [true_wire, orig_not.out_wire(0)]) + .unwrap(); + let or_node = r.node(); + let parent = build.container_node(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(r.outputs(), ®).unwrap(); + + // we delete the original Not and create a new One. This means it will be + // traversed by `constant_fold_pass` after the Or. + let new_not = h.add_node_with_parent(parent, LogicOp::Not); + h.connect(true_wire.node(), true_wire.source(), new_not, 0); + h.disconnect(or_node, IncomingPort::from(1)); + h.connect(new_not, 0, or_node, 1); + h.remove_node(orig_not.node()); + constant_fold_pass(&mut h, ®); + assert_fully_folded(&h, &Value::true_val()) +} + +#[test] +fn test_folding_pass_issue_996() { + // pseudocode: + // + // x0 := 3.0 + // x1 := 4.0 + // x2 := fne(x0, x1); // true + // x3 := flt(x0, x1); // true + // x4 := and(x2, x3); // true + // x5 := -10.0 + // x6 := flt(x0, x5) // false + // x7 := or(x4, x6) // true + // output x7 + let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstF64::new(3.0))); + let x1 = build.add_load_const(Value::extension(ConstF64::new(4.0))); + let x2 = build.add_dataflow_op(FloatOps::fne, [x0, x1]).unwrap(); + let x3 = build.add_dataflow_op(FloatOps::flt, [x0, x1]).unwrap(); + let x4 = build + .add_dataflow_op(LogicOp::And, x2.outputs().chain(x3.outputs())) + .unwrap(); + let x5 = build.add_load_const(Value::extension(ConstF64::new(-10.0))); + let x6 = build.add_dataflow_op(FloatOps::flt, [x0, x5]).unwrap(); + let x7 = build + .add_dataflow_op(LogicOp::Or, x4.outputs().chain(x6.outputs())) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + logic::EXTENSION.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x7.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_const_fold_to_nonfinite() { + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + ]) + .unwrap(); + + // HUGR computing 1.0 / 1.0 + let mut build = DFGBuilder::new(noargfn(vec![FLOAT64_TYPE])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); + let x1 = build.add_load_const(Value::extension(ConstF64::new(1.0))); + let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); + let mut h0 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h0, ®); + assert_fully_folded_with(&h0, |v| { + v.get_custom_value::().unwrap().value() == 1.0 + }); + assert_eq!(h0.node_count(), 5); + + // HUGR computing 1.0 / 0.0 + let mut build = DFGBuilder::new(noargfn(vec![FLOAT64_TYPE])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); + let x1 = build.add_load_const(Value::extension(ConstF64::new(0.0))); + let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); + let mut h1 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h1, ®); + assert_eq!(h1.node_count(), 8); +} + +#[test] +fn test_fold_iwiden_u() { + // pseudocode: + // + // x0 := int_u<4>(13); + // x1 := iwiden_u<4, 5>(x0); + // output x1 == int_u<5>(13); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(4, 13).unwrap())); + let x1 = build + .add_dataflow_op(IntOpDef::iwiden_u.with_two_log_widths(4, 5), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 13).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_iwiden_s() { + // pseudocode: + // + // x0 := int_u<4>(-3); + // x1 := iwiden_u<4, 5>(x0); + // output x1 == int_s<5>(-3); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(4, -3).unwrap())); + let x1 = build + .add_dataflow_op(IntOpDef::iwiden_s.with_two_log_widths(4, 5), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(5, -3).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[rstest] +#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 4, -3, true)] +#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 5, -3, true)] +#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 1, -3, false)] +#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 4, 13, true)] +#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 5, 13, true)] +#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 0, 3, false)] +fn test_fold_inarrow, E: std::fmt::Debug>( + #[case] mk_const: impl Fn(u8, I) -> Result, + #[case] op_def: IntOpDef, + #[case] from_log_width: u8, + #[case] to_log_width: u8, + #[case] val: I, + #[case] succeeds: bool, +) { + // For the first case, pseudocode: + // + // x0 := int_s<5>(-3); + // x1 := inarrow_s<5, 4>(x0); + // output x1 == sum(-3)]>; + // + // Other cases vary by: + // (mk_const, op_def) => create signed or unsigned constants, create + // inarrow_s or inarrow_u ops; + // (from_log_width, to_log_width) => the args to use to create the op; + // val => the value to pass to the op + // succeeds => whether to expect a int variant or an error + // variant. + + use hugr_core::extension::prelude::const_ok; + let elem_type = INT_TYPES[to_log_width as usize].to_owned(); + let sum_type = sum_with_error(elem_type.clone()); + let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); + let x0 = build.add_load_const(mk_const(from_log_width, val).unwrap().into()); + let x1 = build + .add_dataflow_op( + op_def.with_two_log_widths(from_log_width, to_log_width), + [x0], + ) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + lazy_static! { + static ref INARROW_ERROR_VALUE: ConstError = ConstError { + signal: 0, + message: "Integer too large to narrow".to_string(), + }; + } + let expected = if succeeds { + const_ok(mk_const(to_log_width, val).unwrap().into(), ERROR_TYPE) + } else { + INARROW_ERROR_VALUE.clone().as_either(elem_type) + }; + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_itobool() { + // pseudocode: + // + // x0 := int_u<0>(1); + // x1 := itobool(x0); + // output x1 == true; + let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(0, 1).unwrap())); + let x1 = build + .add_dataflow_op(ConvertOpDef::itobool.without_log_width(), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ifrombool() { + // pseudocode: + // + // x0 := false + // x1 := ifrombool(x0); + // output x1 == int_u<0>(0); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[0].clone()])).unwrap(); + let x0 = build.add_load_const(Value::false_val()); + let x1 = build + .add_dataflow_op(ConvertOpDef::ifrombool.without_log_width(), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(0, 0).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ieq() { + // pseudocode: + // x0, x1 := int_s<3>(-1), int_u<3>(255) + // x2 := ieq(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(3, -1).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 255).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ieq.with_log_width(3), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ine() { + // pseudocode: + // x0, x1 := int_u<5>(3), int_u<5>(4) + // x2 := ine(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ine.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ilt_u() { + // pseudocode: + // x0, x1 := int_u<5>(3), int_u<5>(4) + // x2 := ilt_u(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ilt_u.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ilt_s() { + // pseudocode: + // x0, x1 := int_s<5>(3), int_s<5>(-4) + // x2 := ilt_s(x0, x1) + // output x2 == false; + let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ilt_s.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::false_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_igt_u() { + // pseudocode: + // x0, x1 := int_u<5>(3), int_u<5>(4) + // x2 := ilt_u(x0, x1) + // output x2 == false; + let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::igt_u.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::false_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_igt_s() { + // pseudocode: + // x0, x1 := int_s<5>(3), int_s<5>(-4) + // x2 := ilt_s(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::igt_s.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ile_u() { + // pseudocode: + // x0, x1 := int_u<5>(3), int_u<5>(3) + // x2 := ile_u(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ile_u.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ile_s() { + // pseudocode: + // x0, x1 := int_s<5>(-4), int_s<5>(-4) + // x2 := ile_s(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ile_s.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ige_u() { + // pseudocode: + // x0, x1 := int_u<5>(3), int_u<5>(4) + // x2 := ilt_u(x0, x1) + // output x2 == false; + let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ige_u.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::false_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ige_s() { + // pseudocode: + // x0, x1 := int_s<5>(3), int_s<5>(-4) + // x2 := ilt_s(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ige_s.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imax_u() { + // pseudocode: + // x0, x1 := int_u<5>(7), int_u<5>(11); + // x2 := imax_u(x0, x1); + // output x2 == int_u<5>(11); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 7).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 11).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imax_u.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 11).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imax_s() { + // pseudocode: + // x0, x1 := int_s<5>(-2), int_s<5>(1); + // x2 := imax_u(x0, x1); + // output x2 == int_s<5>(1); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 1).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imax_s.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(5, 1).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imin_u() { + // pseudocode: + // x0, x1 := int_u<5>(7), int_u<5>(11); + // x2 := imin_u(x0, x1); + // output x2 == int_u<5>(7); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 7).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 11).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imin_u.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 7).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imin_s() { + // pseudocode: + // x0, x1 := int_s<5>(-2), int_s<5>(1); + // x2 := imin_u(x0, x1); + // output x2 == int_s<5>(-2); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 1).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imin_s.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(5, -2).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_iadd() { + // pseudocode: + // x0, x1 := int_s<5>(-2), int_s<5>(1); + // x2 := iadd(x0, x1); + // output x2 == int_s<5>(-1); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 1).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::iadd.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(5, -1).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_isub() { + // pseudocode: + // x0, x1 := int_s<5>(-2), int_s<5>(1); + // x2 := isub(x0, x1); + // output x2 == int_s<5>(-3); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 1).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::isub.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(5, -3).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ineg() { + // pseudocode: + // x0 := int_s<5>(-2); + // x1 := ineg(x0); + // output x1 == int_s<5>(2); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ineg.with_log_width(5), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(5, 2).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imul() { + // pseudocode: + // x0, x1 := int_s<5>(-2), int_s<5>(7); + // x2 := imul(x0, x1); + // output x2 == int_s<5>(-14); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 7).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imul.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(5, -14).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_idivmod_checked_u() { + // pseudocode: + // x0, x1 := int_u<5>(20), int_u<5>(0) + // x2 := idivmod_checked_u(x0, x1) + // output x2 == error + let intpair: TypeRowRV = vec![INT_TYPES[5].clone(), INT_TYPES[5].clone()].into(); + let elem_type = Type::new_tuple(intpair); + let sum_type = sum_with_error(elem_type.clone()); + let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 0).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::idivmod_checked_u.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = ConstError { + signal: 0, + message: "Division by zero".to_string(), + } + .as_either(elem_type); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_idivmod_u() { + // pseudocode: + // x0, x1 := int_u<3>(20), int_u<3>(3); + // x2, x3 := idivmod_u(x0, x1); // 6, 2 + // x4 := iadd<3>(x2, x3); // 8 + // output x4 == int_u<5>(8); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[3].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(3, 20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap())); + let [x2, x3] = build + .add_dataflow_op(IntOpDef::idivmod_u.with_log_width(3), [x0, x1]) + .unwrap() + .outputs_arr(); + let x4 = build + .add_dataflow_op(IntOpDef::iadd.with_log_width(3), [x2, x3]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x4.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(3, 8).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_idivmod_checked_s() { + // pseudocode: + // x0, x1 := int_s<5>(-20), int_u<5>(0) + // x2 := idivmod_checked_s(x0, x1) + // output x2 == error + let intpair: TypeRowRV = vec![INT_TYPES[5].clone(), INT_TYPES[5].clone()].into(); + let elem_type = Type::new_tuple(intpair); + let sum_type = sum_with_error(elem_type.clone()); + let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 0).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::idivmod_checked_s.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = ConstError { + signal: 0, + message: "Division by zero".to_string(), + } + .as_either(elem_type); + assert_fully_folded(&h, &expected); +} + +#[rstest] +#[case(20, 3, 8)] +#[case(-20, 3, -6)] +#[case(-20, 4, -5)] +#[case(i64::MIN, 1, i64::MIN)] +#[case(i64::MIN, 2, -(1i64 << 62))] +#[case(i64::MIN, 1u64 << 63, -1)] +// c = a/b + a%b +fn test_fold_idivmod_s(#[case] a: i64, #[case] b: u64, #[case] c: i64) { + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[6].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(6, a).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(6, b).unwrap())); + let [x2, x3] = build + .add_dataflow_op(IntOpDef::idivmod_s.with_log_width(6), [x0, x1]) + .unwrap() + .outputs_arr(); + let x4 = build + .add_dataflow_op(IntOpDef::iadd.with_log_width(6), [x2, x3]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x4.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(6, c).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_idiv_checked_u() { + // pseudocode: + // x0, x1 := int_u<5>(20), int_u<5>(0) + // x2 := idiv_checked_u(x0, x1) + // output x2 == error + let sum_type = sum_with_error(INT_TYPES[5].to_owned()); + let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 0).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::idiv_checked_u.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = ConstError { + signal: 0, + message: "Division by zero".to_string(), + } + .as_either(INT_TYPES[5].to_owned()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_idiv_u() { + // pseudocode: + // x0, x1 := int_u<5>(20), int_u<5>(3); + // x2 := idiv_u(x0, x1); + // output x2 == int_u<5>(6); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::idiv_u.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 6).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imod_checked_u() { + // pseudocode: + // x0, x1 := int_u<5>(20), int_u<5>(0) + // x2 := imod_checked_u(x0, x1) + // output x2 == error + let sum_type = sum_with_error(INT_TYPES[5].to_owned()); + let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 0).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imod_checked_u.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = ConstError { + signal: 0, + message: "Division by zero".to_string(), + } + .as_either(INT_TYPES[5].to_owned()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imod_u() { + // pseudocode: + // x0, x1 := int_u<5>(20), int_u<5>(3); + // x2 := imod_u(x0, x1); + // output x2 == int_u<3>(2); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imod_u.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 2).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_idiv_checked_s() { + // pseudocode: + // x0, x1 := int_s<5>(-20), int_u<5>(0) + // x2 := idiv_checked_s(x0, x1) + // output x2 == error + let sum_type = sum_with_error(INT_TYPES[5].to_owned()); + let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 0).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::idiv_checked_s.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = ConstError { + signal: 0, + message: "Division by zero".to_string(), + } + .as_either(INT_TYPES[5].to_owned()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_idiv_s() { + // pseudocode: + // x0, x1 := int_s<5>(-20), int_u<5>(3); + // x2 := idiv_s(x0, x1); + // output x2 == int_s<5>(-7); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::idiv_s.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(5, -7).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imod_checked_s() { + // pseudocode: + // x0, x1 := int_s<5>(-20), int_u<5>(0) + // x2 := imod_checked_u(x0, x1) + // output x2 == error + let sum_type = sum_with_error(INT_TYPES[5].to_owned()); + let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 0).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imod_checked_s.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = ConstError { + signal: 0, + message: "Division by zero".to_string(), + } + .as_either(INT_TYPES[5].to_owned()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imod_s() { + // pseudocode: + // x0, x1 := int_s<5>(-20), int_u<5>(3); + // x2 := imod_s(x0, x1); + // output x2 == int_u<5>(1); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imod_s.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 1).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_iabs() { + // pseudocode: + // x0 := int_s<5>(-2); + // x1 := iabs(x0); + // output x1 == int_s<5>(2); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::iabs.with_log_width(5), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 2).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_iand() { + // pseudocode: + // x0, x1 := int_u<5>(14), int_u<5>(20); + // x2 := iand(x0, x1); + // output x2 == int_u<5>(4); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 14).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::iand.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 4).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ior() { + // pseudocode: + // x0, x1 := int_u<5>(14), int_u<5>(20); + // x2 := ior(x0, x1); + // output x2 == int_u<5>(30); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 14).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ior.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 30).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ixor() { + // pseudocode: + // x0, x1 := int_u<5>(14), int_u<5>(20); + // x2 := ixor(x0, x1); + // output x2 == int_u<5>(26); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 14).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ixor.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 26).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_inot() { + // pseudocode: + // x0 := int_u<5>(14); + // x1 := inot(x0); + // output x1 == int_u<5>(17); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 14).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::inot.with_log_width(5), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, (1u64 << 32) - 15).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ishl() { + // pseudocode: + // x0, x1 := int_u<5>(14), int_u<3>(3); + // x2 := ishl(x0, x1); + // output x2 == int_u<5>(112); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ishl.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 112).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ishr() { + // pseudocode: + // x0, x1 := int_u<5>(14), int_u<3>(3); + // x2 := ishr(x0, x1); + // output x2 == int_u<5>(1); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ishr.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 1).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_irotl() { + // pseudocode: + // x0, x1 := int_u<5>(14), int_u<3>(61); + // x2 := irotl(x0, x1); + // output x2 == int_u<5>(2^30 + 2^31 + 1); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 61).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::irotl.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 3 * (1u64 << 30) + 1).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_irotr() { + // pseudocode: + // x0, x1 := int_u<5>(14), int_u<3>(3); + // x2 := irotr(x0, x1); + // output x2 == int_u<5>(2^30 + 2^31 + 1); + let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::irotr.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 3 * (1u64 << 30) + 1).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_itostring_u() { + // pseudocode: + // x0 := int_u<5>(17); + // x1 := itostring_u(x0); + // output x2 := "17"; + let mut build = DFGBuilder::new(noargfn(vec![STRING_TYPE])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 17).unwrap())); + let x1 = build + .add_dataflow_op(ConvertOpDef::itostring_u.with_log_width(5), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstString::new("17".into())); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_itostring_s() { + // pseudocode: + // x0 := int_s<5>(-17); + // x1 := itostring_s(x0); + // output x2 := "-17"; + let mut build = DFGBuilder::new(noargfn(vec![STRING_TYPE])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -17).unwrap())); + let x1 = build + .add_dataflow_op(ConvertOpDef::itostring_s.with_log_width(5), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstString::new("-17".into())); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_int_ops() { + // pseudocode: + // + // x0 := int_u<5>(3); // 3 + // x1 := int_u<5>(4); // 4 + // x2 := ine(x0, x1); // true + // x3 := ilt_u(x0, x1); // true + // x4 := and(x2, x3); // true + // x5 := int_s<5>(-10) // -10 + // x6 := ilt_s(x0, x5) // false + // x7 := or(x4, x6) // true + // output x7 + let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ine.with_log_width(5), [x0, x1]) + .unwrap(); + let x3 = build + .add_dataflow_op(IntOpDef::ilt_u.with_log_width(5), [x0, x1]) + .unwrap(); + let x4 = build + .add_dataflow_op(LogicOp::And, x2.outputs().chain(x3.outputs())) + .unwrap(); + let x5 = build.add_load_const(Value::extension(ConstInt::new_s(5, -10).unwrap())); + let x6 = build + .add_dataflow_op(IntOpDef::ilt_s.with_log_width(5), [x0, x5]) + .unwrap(); + let x7 = build + .add_dataflow_op(LogicOp::Or, x4.outputs().chain(x6.outputs())) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + logic::EXTENSION.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x7.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} From e64f51934479ee5ce3ab8951e4d2ce2dceacd803 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 28 Oct 2024 16:47:07 +0000 Subject: [PATCH 178/281] tests: Fix test_add --- hugr-passes/src/const_fold2/test.rs | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/const_fold2/test.rs b/hugr-passes/src/const_fold2/test.rs index 110ae0633..8e594a8af 100644 --- a/hugr-passes/src/const_fold2/test.rs +++ b/hugr-passes/src/const_fold2/test.rs @@ -21,7 +21,7 @@ use itertools::Itertools; use lazy_static::lazy_static; use rstest::rstest; -use super::{constant_fold_pass, ConstFoldContext}; +use super::{constant_fold_pass, ConstFoldContext, ValueHandle}; use crate::dataflow::{ConstLoader, DFContext, PartialValue}; #[rstest] @@ -98,13 +98,20 @@ fn f2c(f: f64) -> Value { #[case(23.5, 435.5, 459.0)] // c = a + b fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { + fn unwrap_float(pv: PartialValue) -> f64 { + pv.try_into_value::(&FLOAT64_TYPE) + .unwrap() + .get_custom_value::() + .unwrap() + .value() + } let [n, n_a, n_b] = [0, 1, 2].map(portgraph::NodeIndex::new).map(Node::from); let mut temp = Hugr::default(); let ctx = ConstFoldContext(&mut temp); let v_a = ctx.value_from_const(n_a, &f2c(a)); let v_b = ctx.value_from_const(n_b, &f2c(b)); - assert_eq!(v_a.clone().try_into_value(&FLOAT64_TYPE), Ok(f2c(a))); - assert_eq!(v_b.clone().try_into_value(&FLOAT64_TYPE), Ok(f2c(b))); + assert_eq!(unwrap_float(v_a.clone()), a); + assert_eq!(unwrap_float(v_b.clone()), b); let mut outs = [PartialValue::Bottom]; let OpType::ExtensionOp(add_op) = OpType::from(FloatOps::fadd) else { @@ -112,7 +119,7 @@ fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { }; ctx.interpret_leaf_op(n, &add_op, &[v_a, v_b], &mut outs); - assert_eq!(outs[0].clone().try_into_value(&FLOAT64_TYPE), Ok(f2c(c))); + assert_eq!(unwrap_float(outs[0].clone()), c); } fn noargfn(outputs: impl Into) -> Signature { From 58a1df23ce4bd3569ec229f602967bb1eb55dfa3 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 28 Oct 2024 16:46:35 +0000 Subject: [PATCH 179/281] fix: Improve find_needed_nodes --- hugr-passes/src/const_fold2.rs | 81 ++++++++++++++++------------------ 1 file changed, 39 insertions(+), 42 deletions(-) diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index 49d7bf457..30a037dd4 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -88,54 +88,51 @@ impl ConstFoldPass { fn find_needed_nodes( &self, results: &AnalysisResults, - container: Node, + root: Node, needed: &mut HashSet, ) { + let mut q = VecDeque::new(); + q.push_back(root); let h = &results.hugr; - if h.get_optype(container).is_cfg() { - for bb in h.children(container) { - if results.bb_reachable(bb).unwrap() - && needed.insert(bb) - && h.get_optype(bb).is_dataflow_block() - { - self.find_needed_nodes(results, bb, needed); - } - } - } else { - // Dataflow. Find minimal nodes necessary to compute output, including StateOrder edges. - let [_inp, outp] = h.get_io(container).unwrap(); - let mut q = VecDeque::new(); - q.push_back(outp); - // Add on anything that might not terminate. We might also allow a custom predicate for extension ops? - for n in h.children(container) { - if h.get_optype(n).is_cfg() - || (!self.allow_skip_loops - && h.get_optype(n).is_tail_loop() - && results.tail_loop_terminates(n).unwrap() - != TailLoopTermination::NeverContinues) - { - q.push_back(n); - } - } - while let Some(n) = q.pop_front() { - if !needed.insert(n) { - continue; + while let Some(n) = q.pop_front() { + if !needed.insert(n) { + continue; + }; + if h.get_optype(n).is_cfg() { + for bb in h.children(n) { + if results.bb_reachable(bb).unwrap() + && needed.insert(bb) + && h.get_optype(bb).is_dataflow_block() + { + q.push_back(bb); + } } - for (src, op) in h.all_linked_outputs(n) { - let needs_predecessor = match h.get_optype(src).port_kind(op).unwrap() { - EdgeKind::Value(_) => { - results.try_read_wire_value(Wire::new(src, op)).is_err() - } - EdgeKind::StateOrder | EdgeKind::Const(_) | EdgeKind::Function(_) => true, - EdgeKind::ControlFlow => panic!(), - _ => true, // needed for non-exhaustive; not knowing what it is, assume the worst - }; - if needs_predecessor { - q.push_back(src); + } else if let Some(inout) = h.get_io(n) { + // Dataflow. Find minimal nodes necessary to compute output, including StateOrder edges. + q.extend(inout); // Input also necessary for legality even if unreachable + + // Also add on anything that might not terminate. We might also allow a custom predicate for extension ops? + for ch in h.children(n) { + if h.get_optype(ch).is_cfg() + || (!self.allow_skip_loops + && h.get_optype(ch).is_tail_loop() + && results.tail_loop_terminates(ch).unwrap() + != TailLoopTermination::NeverContinues) + { + q.push_back(ch); } } - if h.get_optype(n).is_container() { - self.find_needed_nodes(results, container, needed); + } + // Also follow dataflow demand + for (src, op) in h.all_linked_outputs(n) { + let needs_predecessor = match h.get_optype(src).port_kind(op).unwrap() { + EdgeKind::Value(_) => results.try_read_wire_value(Wire::new(src, op)).is_err(), + EdgeKind::StateOrder | EdgeKind::Const(_) | EdgeKind::Function(_) => true, + EdgeKind::ControlFlow => panic!(), + _ => true, // needed for non-exhaustive; not knowing what it is, assume the worst + }; + if needs_predecessor { + q.push_back(src); } } } From 6687656d5f6c5141236bfd1f73cb5e5ebae40c35 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 28 Oct 2024 16:46:53 +0000 Subject: [PATCH 180/281] fix: Connect constant to use --- hugr-passes/src/const_fold2.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index 30a037dd4..a43682282 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -66,6 +66,8 @@ impl ConstFoldPass { let cst = hugr_mut.add_node_with_parent(parent, Const::new(v)); let lcst = hugr_mut.add_node_with_parent(parent, LoadConstant { datatype }); hugr_mut.connect(cst, OutgoingPort::from(0), lcst, IncomingPort::from(0)); + hugr_mut.disconnect(n, inport); + hugr_mut.connect(lcst, OutgoingPort::from(0), n, inport); } } } From c7ba482b4babefe764ab1ad285277d3a0a02a2c3 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 28 Oct 2024 17:07:02 +0000 Subject: [PATCH 181/281] Don't break LoadConstant out_wires; skip unnecessary cloning --- hugr-passes/src/const_fold2.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index a43682282..240e4f73a 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -58,6 +58,9 @@ impl ConstFoldPass { EdgeKind::Value(_) ) { let (src, outp) = results.hugr.single_linked_output(n, inport).unwrap(); + if results.hugr.get_optype(src).is_load_constant() { + continue; + } if let Ok(v) = results.try_read_wire_value(Wire::new(src, outp)) { let parent = results.hugr.get_parent(n).unwrap(); let datatype = v.get_type(); @@ -128,7 +131,10 @@ impl ConstFoldPass { // Also follow dataflow demand for (src, op) in h.all_linked_outputs(n) { let needs_predecessor = match h.get_optype(src).port_kind(op).unwrap() { - EdgeKind::Value(_) => results.try_read_wire_value(Wire::new(src, op)).is_err(), + EdgeKind::Value(_) => { + results.hugr.get_optype(src).is_load_constant() + || results.try_read_wire_value(Wire::new(src, op)).is_err() + } EdgeKind::StateOrder | EdgeKind::Const(_) | EdgeKind::Function(_) => true, EdgeKind::ControlFlow => panic!(), _ => true, // needed for non-exhaustive; not knowing what it is, assume the worst @@ -202,7 +208,6 @@ impl<'a, H: HugrView> TotalContext for ConstFoldContext<'a, H> { op: &ExtensionOp, ins: &[(IncomingPort, Value)], ) -> Vec<(OutgoingPort, PartialValue)> { - let ins = ins.iter().map(|(p, v)| (*p, v.clone())).collect::>(); op.constant_fold(&ins).map_or(Vec::new(), |outs| { outs.into_iter() .map(|(p, v)| { From 7484fb992e6ffed09a27d5ff5fad7c4b96dfc11c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 28 Oct 2024 17:36:36 +0000 Subject: [PATCH 182/281] Remove old const_fold and rename over --- hugr-passes/src/const_fold.rs | 362 ++-- hugr-passes/src/const_fold/test.rs | 93 +- .../value_handle.rs | 0 hugr-passes/src/const_fold2.rs | 225 --- hugr-passes/src/const_fold2/test.rs | 1559 ----------------- hugr-passes/src/lib.rs | 1 - 6 files changed, 253 insertions(+), 1987 deletions(-) rename hugr-passes/src/{const_fold2 => const_fold}/value_handle.rs (100%) delete mode 100644 hugr-passes/src/const_fold2.rs delete mode 100644 hugr-passes/src/const_fold2/test.rs diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index b5e303d43..240e4f73a 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -1,220 +1,224 @@ -//! Constant folding routines. +//! An (example) use of the [super::dataflow](dataflow-analysis framework) +//! to perform constant-folding. -use std::collections::{BTreeSet, HashMap}; +// These are pub because this "example" is used for testing the framework. +pub mod value_handle; +use std::collections::{HashSet, VecDeque}; -use hugr_core::builder::inout_sig; -use itertools::Itertools; -use thiserror::Error; - -use hugr_core::hugr::SimpleReplacementError; -use hugr_core::types::SumType; -use hugr_core::Direction; use hugr_core::{ - builder::{DFGBuilder, Dataflow, DataflowHugr}, - extension::{fold_out_row, ConstFoldResult, ExtensionRegistry}, + extension::ExtensionRegistry, hugr::{ hugrmut::HugrMut, - rewrite::consts::{RemoveConst, RemoveLoadConstant}, - views::SiblingSubgraph, + views::{DescendantsGraph, ExtractHugr, HierarchyView}, }, - ops::{OpType, Value}, - type_row, Hugr, HugrView, IncomingPort, Node, SimpleReplacement, + ops::{constant::OpaqueValue, handle::FuncID, Const, ExtensionOp, LoadConstant, Value}, + types::{EdgeKind, TypeArg}, + HugrView, IncomingPort, Node, OutgoingPort, Wire, }; +use value_handle::ValueHandle; -use crate::validation::{ValidatePassError, ValidationLevel}; - -#[derive(Error, Debug)] -#[allow(missing_docs)] -pub enum ConstFoldError { - #[error(transparent)] - SimpleReplacementError(#[from] SimpleReplacementError), - #[error(transparent)] - ValidationError(#[from] ValidatePassError), -} +use crate::{ + dataflow::{ + AnalysisResults, ConstLoader, Machine, PartialValue, TailLoopTermination, TotalContext, + }, + validation::{ValidatePassError, ValidationLevel}, +}; -#[derive(Debug, Clone, Copy, Default)] +#[derive(Debug, Clone, Default)] /// A configuration for the Constant Folding pass. -pub struct ConstantFoldPass { +pub struct ConstFoldPass { validation: ValidationLevel, + /// If true, allow to skip evaluating loops (whose results are not needed) even if + /// we are not sure they will terminate. (If they definitely terminate then fair game.) + pub allow_skip_loops: bool, } -impl ConstantFoldPass { - /// Create a new `ConstFoldConfig` with default configuration. - pub fn new() -> Self { - Self::default() - } - - /// Build a `ConstFoldConfig` with the given [ValidationLevel]. +impl ConstFoldPass { pub fn validation_level(mut self, level: ValidationLevel) -> Self { self.validation = level; self } /// Run the Constant Folding pass. + fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { + let results = Machine::default().run(ConstFoldContext(hugr), []); + let mut keep_nodes = HashSet::new(); + self.find_needed_nodes(&results, results.hugr.root(), &mut keep_nodes); + + let remove_nodes = results + .hugr + .nodes() + .filter(|n| !keep_nodes.contains(n)) + .collect::>(); + for n in keep_nodes { + // Every input either (a) is in keep_nodes, or (b) has a known value. Break all wires (b). + for inport in results.hugr.node_inputs(n) { + if matches!( + results.hugr.get_optype(n).port_kind(inport).unwrap(), + EdgeKind::Value(_) + ) { + let (src, outp) = results.hugr.single_linked_output(n, inport).unwrap(); + if results.hugr.get_optype(src).is_load_constant() { + continue; + } + if let Ok(v) = results.try_read_wire_value(Wire::new(src, outp)) { + let parent = results.hugr.get_parent(n).unwrap(); + let datatype = v.get_type(); + // We could try hash-consing identical Consts, but not ATM + let hugr_mut = &mut *results.hugr.0; + let cst = hugr_mut.add_node_with_parent(parent, Const::new(v)); + let lcst = hugr_mut.add_node_with_parent(parent, LoadConstant { datatype }); + hugr_mut.connect(cst, OutgoingPort::from(0), lcst, IncomingPort::from(0)); + hugr_mut.disconnect(n, inport); + hugr_mut.connect(lcst, OutgoingPort::from(0), n, inport); + } + } + } + } + for n in remove_nodes { + hugr.remove_node(n); + } + Ok(()) + } + pub fn run( &self, hugr: &mut H, reg: &ExtensionRegistry, - ) -> Result<(), ConstFoldError> { + ) -> Result<(), ValidatePassError> { self.validation - .run_validated_pass(hugr, reg, |hugr: &mut H, _| { - loop { - // We can only safely apply a single replacement. Applying a - // replacement removes nodes and edges which may be referenced by - // further replacements returned by find_consts. Even worse, if we - // attempted to apply those replacements, expecting them to fail if - // the nodes and edges they reference had been deleted, they may - // succeed because new nodes and edges reused the ids. - // - // We could be a lot smarter here, keeping track of `LoadConstant` - // nodes and only looking at their out neighbours. - let Some((replace, removes)) = find_consts(hugr, hugr.nodes(), reg).next() - else { - break Ok(()); - }; - hugr.apply_rewrite(replace)?; - for rem in removes { - // We are optimistically applying these [RemoveLoadConstant] and - // [RemoveConst] rewrites without checking whether the nodes - // they attempt to remove have remaining uses. If they do, then - // the rewrite fails and we move on. - if let Ok(const_node) = hugr.apply_rewrite(rem) { - // if the LoadConst was removed, try removing the Const too. - let _ = hugr.apply_rewrite(RemoveConst(const_node)); - } + .run_validated_pass(hugr, reg, |hugr: &mut H, _| self.run_no_validate(hugr)) + } + + fn find_needed_nodes( + &self, + results: &AnalysisResults, + root: Node, + needed: &mut HashSet, + ) { + let mut q = VecDeque::new(); + q.push_back(root); + let h = &results.hugr; + while let Some(n) = q.pop_front() { + if !needed.insert(n) { + continue; + }; + if h.get_optype(n).is_cfg() { + for bb in h.children(n) { + if results.bb_reachable(bb).unwrap() + && needed.insert(bb) + && h.get_optype(bb).is_dataflow_block() + { + q.push_back(bb); } } - }) + } else if let Some(inout) = h.get_io(n) { + // Dataflow. Find minimal nodes necessary to compute output, including StateOrder edges. + q.extend(inout); // Input also necessary for legality even if unreachable + + // Also add on anything that might not terminate. We might also allow a custom predicate for extension ops? + for ch in h.children(n) { + if h.get_optype(ch).is_cfg() + || (!self.allow_skip_loops + && h.get_optype(ch).is_tail_loop() + && results.tail_loop_terminates(ch).unwrap() + != TailLoopTermination::NeverContinues) + { + q.push_back(ch); + } + } + } + // Also follow dataflow demand + for (src, op) in h.all_linked_outputs(n) { + let needs_predecessor = match h.get_optype(src).port_kind(op).unwrap() { + EdgeKind::Value(_) => { + results.hugr.get_optype(src).is_load_constant() + || results.try_read_wire_value(Wire::new(src, op)).is_err() + } + EdgeKind::StateOrder | EdgeKind::Const(_) | EdgeKind::Function(_) => true, + EdgeKind::ControlFlow => panic!(), + _ => true, // needed for non-exhaustive; not knowing what it is, assume the worst + }; + if needs_predecessor { + q.push_back(src); + } + } + } } } -/// For a given op and consts, attempt to evaluate the op. -pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldResult { - let fold_result = match op { - OpType::Tag(t) => fold_out_row([Value::sum( - t.tag, - consts.iter().map(|(_, konst)| konst.clone()), - SumType::new(t.variants.clone()), - ) - .unwrap()]), - OpType::ExtensionOp(ext_op) => ext_op.constant_fold(consts), - _ => None, - }; - debug_assert!(fold_result.as_ref().map_or(true, |x| x.len() - == op.value_port_count(Direction::Outgoing))); - fold_result +/// Exhaustively apply constant folding to a HUGR. +pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { + ConstFoldPass::default().run(h, reg).unwrap() } -/// Generate a graph that loads and outputs `consts` in order, validating -/// against `reg`. -fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { - let const_types = consts.iter().map(Value::get_type).collect_vec(); - let mut b = DFGBuilder::new(inout_sig(type_row![], const_types)).unwrap(); - - let outputs = consts - .into_iter() - .map(|c| b.add_load_const(c)) - .collect_vec(); +struct ConstFoldContext<'a, H>(&'a mut H); - b.finish_hugr_with_outputs(outputs, reg).unwrap() +impl<'a, T: HugrView> AsRef for ConstFoldContext<'a, T> { + fn as_ref(&self) -> &hugr_core::Hugr { + self.0.base_hugr() + } } -/// Given some `candidate_nodes` to search for LoadConstant operations in `hugr`, -/// return an iterator of possible constant folding rewrites. -/// -/// The [`SimpleReplacement`] replaces an operation with constants that result from -/// evaluating it, the extension registry `reg` is used to validate the -/// replacement HUGR. The vector of [`RemoveLoadConstant`] refer to the -/// LoadConstant nodes that could be removed - they are not automatically -/// removed as they may be used by other operations. -pub fn find_consts<'a, 'r: 'a>( - hugr: &'a impl HugrView, - candidate_nodes: impl IntoIterator + 'a, - reg: &'r ExtensionRegistry, -) -> impl Iterator)> + 'a { - // track nodes for operations that have already been considered for folding - let mut used_neighbours = BTreeSet::new(); - - candidate_nodes - .into_iter() - .filter_map(move |n| { - // only look at LoadConstant - hugr.get_optype(n).is_load_constant().then_some(())?; - - let (out_p, _) = hugr.out_value_types(n).exactly_one().ok()?; - let neighbours = hugr - .linked_inputs(n, out_p) - .filter(|(n, _)| used_neighbours.insert(*n)) - .collect_vec(); - if neighbours.is_empty() { - // no uses of LoadConstant that haven't already been considered. - return None; - } - let fold_iter = neighbours - .into_iter() - .filter_map(|(neighbour, _)| fold_op(hugr, neighbour, reg)); - Some(fold_iter) - }) - .flatten() -} +impl<'a, H: HugrView> ConstLoader for ConstFoldContext<'a, H> { + fn value_from_opaque( + &self, + node: Node, + fields: &[usize], + val: &OpaqueValue, + ) -> Option { + Some(ValueHandle::new_opaque(node, fields, val.clone())) + } -/// Attempt to evaluate and generate rewrites for the operation at `op_node` -fn fold_op( - hugr: &impl HugrView, - op_node: Node, - reg: &ExtensionRegistry, -) -> Option<(SimpleReplacement, Vec)> { - // only support leaf folding for now. - let neighbour_op = hugr.get_optype(op_node); - let (in_consts, removals): (Vec<_>, Vec<_>) = hugr - .node_inputs(op_node) - .filter_map(|in_p| { - let (con_op, load_n) = get_const(hugr, op_node, in_p)?; - Some(((in_p, con_op), RemoveLoadConstant(load_n))) - }) - .unzip(); - // attempt to evaluate op - let (nu_out, consts): (HashMap<_, _>, Vec<_>) = fold_leaf_op(neighbour_op, &in_consts)? - .into_iter() - .enumerate() - .filter_map(|(i, (op_out, konst))| { - // for each used port of the op give the nu_out entry and the - // corresponding Value - hugr.single_linked_input(op_node, op_out) - .map(|np| ((np, i.into()), konst)) - }) - .unzip(); - let replacement = const_graph(consts, reg); - let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr) - .expect("Operation should form valid subgraph."); - - let simple_replace = SimpleReplacement::new( - sibling_graph, - replacement, - // no inputs to replacement - HashMap::new(), - nu_out, - ); - Some((simple_replace, removals)) -} + fn value_from_const_hugr( + &self, + node: Node, + fields: &[usize], + h: &hugr_core::Hugr, + ) -> Option { + Some(ValueHandle::new_const_hugr( + node, + fields, + Box::new(h.clone()), + )) + } -/// If `op_node` is connected to a LoadConstant at `in_p`, return the constant -/// and the LoadConstant node -fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<(Value, Node)> { - let (load_n, _) = hugr.single_linked_output(op_node, in_p)?; - let load_op = hugr.get_optype(load_n).as_load_constant()?; - let const_node = hugr - .single_linked_output(load_n, load_op.constant_port())? - .0; - let const_op = hugr.get_optype(const_node).as_const()?; - - // TODO avoid const clone here - Some((const_op.as_ref().clone(), load_n)) + fn value_from_function(&self, node: Node, type_args: &[TypeArg]) -> Option { + if type_args.len() > 0 { + // TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709) + return None; + }; + // Returning the function body as a value, here, would be sufficient for inlining IndirectCall + // but not for transforming to a direct Call. + let func = DescendantsGraph::>::try_new(self, node).ok()?; + Some(ValueHandle::new_const_hugr( + node, + &[], + Box::new(func.extract_hugr()), + )) + } } -/// Exhaustively apply constant folding to a HUGR. -pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { - ConstantFoldPass::default().run(h, reg).unwrap() +impl<'a, H: HugrView> TotalContext for ConstFoldContext<'a, H> { + type InterpretableVal = Value; + + fn interpret_leaf_op( + &self, + n: Node, + op: &ExtensionOp, + ins: &[(IncomingPort, Value)], + ) -> Vec<(OutgoingPort, PartialValue)> { + op.constant_fold(&ins).map_or(Vec::new(), |outs| { + outs.into_iter() + .map(|(p, v)| { + ( + p, + self.value_from_const(n, &v), // Hmmm, should (at least) also key by p + ) + }) + .collect() + }) + } } #[cfg(test)] diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 58ee38bcf..8e594a8af 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -1,27 +1,61 @@ -use crate::const_fold::constant_fold_pass; -use hugr_core::builder::{DFGBuilder, Dataflow, DataflowHugr}; +use hugr_core::builder::{inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr}; use hugr_core::extension::prelude::{ const_ok, sum_with_error, ConstError, ConstString, UnpackTuple, BOOL_T, ERROR_TYPE, STRING_TYPE, }; use hugr_core::extension::{ExtensionRegistry, PRELUDE}; -use hugr_core::ops::Value; -use hugr_core::std_extensions::arithmetic; -use hugr_core::std_extensions::arithmetic::int_ops::IntOpDef; -use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; +use hugr_core::hugr::hugrmut::HugrMut; +use hugr_core::ops::{constant::CustomConst, OpType, Value}; +use hugr_core::std_extensions::arithmetic::{ + self, + conversions::ConvertOpDef, + float_ops::FloatOps, + float_types::{ConstF64, FLOAT64_TYPE}, + int_ops::IntOpDef, + int_types::{ConstInt, INT_TYPES}, +}; use hugr_core::std_extensions::logic::{self, LogicOp}; -use hugr_core::type_row; -use hugr_core::types::{Signature, Type, TypeRow, TypeRowRV}; +use hugr_core::types::{Signature, SumType, Type, TypeRow, TypeRowRV}; +use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node}; +use itertools::Itertools; +use lazy_static::lazy_static; use rstest::rstest; -use lazy_static::lazy_static; +use super::{constant_fold_pass, ConstFoldContext, ValueHandle}; +use crate::dataflow::{ConstLoader, DFContext, PartialValue}; -use super::*; -use hugr_core::builder::Container; -use hugr_core::ops::OpType; -use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; -use hugr_core::std_extensions::arithmetic::float_ops::FloatOps; -use hugr_core::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; +#[rstest] +#[case(ConstInt::new_u(4, 2).unwrap(), true)] +#[case(ConstF64::new(std::f64::consts::PI), false)] +fn value_handling(#[case] k: impl CustomConst + Clone, #[case] eq: bool) { + let n = Node::from(portgraph::NodeIndex::new(7)); + let st = SumType::new([vec![k.get_type()], vec![]]); + let subject_val = Value::sum(0, [k.clone().into()], st).unwrap(); + let mut temp = Hugr::default(); + let ctx: ConstFoldContext = ConstFoldContext(&mut temp); + let v1 = ctx.value_from_const(n, &subject_val); + + let v1_subfield = { + let PartialValue::PartialSum(ps1) = v1 else { + panic!() + }; + ps1.0 + .into_iter() + .exactly_one() + .unwrap() + .1 + .into_iter() + .exactly_one() + .unwrap() + }; + + let v2 = ctx.value_from_const(n, &k.into()); + if eq { + assert_eq!(v1_subfield, v2); + } else { + assert_ne!(v1_subfield, v2); + } +} /// Check that a hugr just loads and returns a single expected constant. pub fn assert_fully_folded(h: &Hugr, expected_value: &Value) { @@ -64,15 +98,28 @@ fn f2c(f: f64) -> Value { #[case(23.5, 435.5, 459.0)] // c = a + b fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { - let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))]; - let add_op: OpType = FloatOps::fadd.into(); - let outs = fold_leaf_op(&add_op, &consts) - .unwrap() - .into_iter() - .map(|(p, v)| (p, v.get_custom_value::().unwrap().value())) - .collect_vec(); + fn unwrap_float(pv: PartialValue) -> f64 { + pv.try_into_value::(&FLOAT64_TYPE) + .unwrap() + .get_custom_value::() + .unwrap() + .value() + } + let [n, n_a, n_b] = [0, 1, 2].map(portgraph::NodeIndex::new).map(Node::from); + let mut temp = Hugr::default(); + let ctx = ConstFoldContext(&mut temp); + let v_a = ctx.value_from_const(n_a, &f2c(a)); + let v_b = ctx.value_from_const(n_b, &f2c(b)); + assert_eq!(unwrap_float(v_a.clone()), a); + assert_eq!(unwrap_float(v_b.clone()), b); + + let mut outs = [PartialValue::Bottom]; + let OpType::ExtensionOp(add_op) = OpType::from(FloatOps::fadd) else { + panic!() + }; + ctx.interpret_leaf_op(n, &add_op, &[v_a, v_b], &mut outs); - assert_eq!(outs.as_slice(), &[(0.into(), c)]); + assert_eq!(unwrap_float(outs[0].clone()), c); } fn noargfn(outputs: impl Into) -> Signature { diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold/value_handle.rs similarity index 100% rename from hugr-passes/src/const_fold2/value_handle.rs rename to hugr-passes/src/const_fold/value_handle.rs diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs deleted file mode 100644 index 240e4f73a..000000000 --- a/hugr-passes/src/const_fold2.rs +++ /dev/null @@ -1,225 +0,0 @@ -//! An (example) use of the [super::dataflow](dataflow-analysis framework) -//! to perform constant-folding. - -// These are pub because this "example" is used for testing the framework. -pub mod value_handle; -use std::collections::{HashSet, VecDeque}; - -use hugr_core::{ - extension::ExtensionRegistry, - hugr::{ - hugrmut::HugrMut, - views::{DescendantsGraph, ExtractHugr, HierarchyView}, - }, - ops::{constant::OpaqueValue, handle::FuncID, Const, ExtensionOp, LoadConstant, Value}, - types::{EdgeKind, TypeArg}, - HugrView, IncomingPort, Node, OutgoingPort, Wire, -}; -use value_handle::ValueHandle; - -use crate::{ - dataflow::{ - AnalysisResults, ConstLoader, Machine, PartialValue, TailLoopTermination, TotalContext, - }, - validation::{ValidatePassError, ValidationLevel}, -}; - -#[derive(Debug, Clone, Default)] -/// A configuration for the Constant Folding pass. -pub struct ConstFoldPass { - validation: ValidationLevel, - /// If true, allow to skip evaluating loops (whose results are not needed) even if - /// we are not sure they will terminate. (If they definitely terminate then fair game.) - pub allow_skip_loops: bool, -} - -impl ConstFoldPass { - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - - /// Run the Constant Folding pass. - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { - let results = Machine::default().run(ConstFoldContext(hugr), []); - let mut keep_nodes = HashSet::new(); - self.find_needed_nodes(&results, results.hugr.root(), &mut keep_nodes); - - let remove_nodes = results - .hugr - .nodes() - .filter(|n| !keep_nodes.contains(n)) - .collect::>(); - for n in keep_nodes { - // Every input either (a) is in keep_nodes, or (b) has a known value. Break all wires (b). - for inport in results.hugr.node_inputs(n) { - if matches!( - results.hugr.get_optype(n).port_kind(inport).unwrap(), - EdgeKind::Value(_) - ) { - let (src, outp) = results.hugr.single_linked_output(n, inport).unwrap(); - if results.hugr.get_optype(src).is_load_constant() { - continue; - } - if let Ok(v) = results.try_read_wire_value(Wire::new(src, outp)) { - let parent = results.hugr.get_parent(n).unwrap(); - let datatype = v.get_type(); - // We could try hash-consing identical Consts, but not ATM - let hugr_mut = &mut *results.hugr.0; - let cst = hugr_mut.add_node_with_parent(parent, Const::new(v)); - let lcst = hugr_mut.add_node_with_parent(parent, LoadConstant { datatype }); - hugr_mut.connect(cst, OutgoingPort::from(0), lcst, IncomingPort::from(0)); - hugr_mut.disconnect(n, inport); - hugr_mut.connect(lcst, OutgoingPort::from(0), n, inport); - } - } - } - } - for n in remove_nodes { - hugr.remove_node(n); - } - Ok(()) - } - - pub fn run( - &self, - hugr: &mut H, - reg: &ExtensionRegistry, - ) -> Result<(), ValidatePassError> { - self.validation - .run_validated_pass(hugr, reg, |hugr: &mut H, _| self.run_no_validate(hugr)) - } - - fn find_needed_nodes( - &self, - results: &AnalysisResults, - root: Node, - needed: &mut HashSet, - ) { - let mut q = VecDeque::new(); - q.push_back(root); - let h = &results.hugr; - while let Some(n) = q.pop_front() { - if !needed.insert(n) { - continue; - }; - if h.get_optype(n).is_cfg() { - for bb in h.children(n) { - if results.bb_reachable(bb).unwrap() - && needed.insert(bb) - && h.get_optype(bb).is_dataflow_block() - { - q.push_back(bb); - } - } - } else if let Some(inout) = h.get_io(n) { - // Dataflow. Find minimal nodes necessary to compute output, including StateOrder edges. - q.extend(inout); // Input also necessary for legality even if unreachable - - // Also add on anything that might not terminate. We might also allow a custom predicate for extension ops? - for ch in h.children(n) { - if h.get_optype(ch).is_cfg() - || (!self.allow_skip_loops - && h.get_optype(ch).is_tail_loop() - && results.tail_loop_terminates(ch).unwrap() - != TailLoopTermination::NeverContinues) - { - q.push_back(ch); - } - } - } - // Also follow dataflow demand - for (src, op) in h.all_linked_outputs(n) { - let needs_predecessor = match h.get_optype(src).port_kind(op).unwrap() { - EdgeKind::Value(_) => { - results.hugr.get_optype(src).is_load_constant() - || results.try_read_wire_value(Wire::new(src, op)).is_err() - } - EdgeKind::StateOrder | EdgeKind::Const(_) | EdgeKind::Function(_) => true, - EdgeKind::ControlFlow => panic!(), - _ => true, // needed for non-exhaustive; not knowing what it is, assume the worst - }; - if needs_predecessor { - q.push_back(src); - } - } - } - } -} - -/// Exhaustively apply constant folding to a HUGR. -pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { - ConstFoldPass::default().run(h, reg).unwrap() -} - -struct ConstFoldContext<'a, H>(&'a mut H); - -impl<'a, T: HugrView> AsRef for ConstFoldContext<'a, T> { - fn as_ref(&self) -> &hugr_core::Hugr { - self.0.base_hugr() - } -} - -impl<'a, H: HugrView> ConstLoader for ConstFoldContext<'a, H> { - fn value_from_opaque( - &self, - node: Node, - fields: &[usize], - val: &OpaqueValue, - ) -> Option { - Some(ValueHandle::new_opaque(node, fields, val.clone())) - } - - fn value_from_const_hugr( - &self, - node: Node, - fields: &[usize], - h: &hugr_core::Hugr, - ) -> Option { - Some(ValueHandle::new_const_hugr( - node, - fields, - Box::new(h.clone()), - )) - } - - fn value_from_function(&self, node: Node, type_args: &[TypeArg]) -> Option { - if type_args.len() > 0 { - // TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709) - return None; - }; - // Returning the function body as a value, here, would be sufficient for inlining IndirectCall - // but not for transforming to a direct Call. - let func = DescendantsGraph::>::try_new(self, node).ok()?; - Some(ValueHandle::new_const_hugr( - node, - &[], - Box::new(func.extract_hugr()), - )) - } -} - -impl<'a, H: HugrView> TotalContext for ConstFoldContext<'a, H> { - type InterpretableVal = Value; - - fn interpret_leaf_op( - &self, - n: Node, - op: &ExtensionOp, - ins: &[(IncomingPort, Value)], - ) -> Vec<(OutgoingPort, PartialValue)> { - op.constant_fold(&ins).map_or(Vec::new(), |outs| { - outs.into_iter() - .map(|(p, v)| { - ( - p, - self.value_from_const(n, &v), // Hmmm, should (at least) also key by p - ) - }) - .collect() - }) - } -} - -#[cfg(test)] -mod test; diff --git a/hugr-passes/src/const_fold2/test.rs b/hugr-passes/src/const_fold2/test.rs deleted file mode 100644 index 8e594a8af..000000000 --- a/hugr-passes/src/const_fold2/test.rs +++ /dev/null @@ -1,1559 +0,0 @@ -use hugr_core::builder::{inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr}; -use hugr_core::extension::prelude::{ - const_ok, sum_with_error, ConstError, ConstString, UnpackTuple, BOOL_T, ERROR_TYPE, STRING_TYPE, -}; -use hugr_core::extension::{ExtensionRegistry, PRELUDE}; -use hugr_core::hugr::hugrmut::HugrMut; -use hugr_core::ops::{constant::CustomConst, OpType, Value}; -use hugr_core::std_extensions::arithmetic::{ - self, - conversions::ConvertOpDef, - float_ops::FloatOps, - float_types::{ConstF64, FLOAT64_TYPE}, - int_ops::IntOpDef, - int_types::{ConstInt, INT_TYPES}, -}; -use hugr_core::std_extensions::logic::{self, LogicOp}; -use hugr_core::types::{Signature, SumType, Type, TypeRow, TypeRowRV}; -use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node}; - -use itertools::Itertools; -use lazy_static::lazy_static; -use rstest::rstest; - -use super::{constant_fold_pass, ConstFoldContext, ValueHandle}; -use crate::dataflow::{ConstLoader, DFContext, PartialValue}; - -#[rstest] -#[case(ConstInt::new_u(4, 2).unwrap(), true)] -#[case(ConstF64::new(std::f64::consts::PI), false)] -fn value_handling(#[case] k: impl CustomConst + Clone, #[case] eq: bool) { - let n = Node::from(portgraph::NodeIndex::new(7)); - let st = SumType::new([vec![k.get_type()], vec![]]); - let subject_val = Value::sum(0, [k.clone().into()], st).unwrap(); - let mut temp = Hugr::default(); - let ctx: ConstFoldContext = ConstFoldContext(&mut temp); - let v1 = ctx.value_from_const(n, &subject_val); - - let v1_subfield = { - let PartialValue::PartialSum(ps1) = v1 else { - panic!() - }; - ps1.0 - .into_iter() - .exactly_one() - .unwrap() - .1 - .into_iter() - .exactly_one() - .unwrap() - }; - - let v2 = ctx.value_from_const(n, &k.into()); - if eq { - assert_eq!(v1_subfield, v2); - } else { - assert_ne!(v1_subfield, v2); - } -} - -/// Check that a hugr just loads and returns a single expected constant. -pub fn assert_fully_folded(h: &Hugr, expected_value: &Value) { - assert_fully_folded_with(h, |v| v == expected_value) -} - -/// Check that a hugr just loads and returns a single constant, and validate -/// that constant using `check_value`. -/// -/// [CustomConst::equals_const] is not required to be implemented. Use this -/// function for Values containing such a `CustomConst`. -fn assert_fully_folded_with(h: &Hugr, check_value: impl Fn(&Value) -> bool) { - let mut node_count = 0; - - for node in h.children(h.root()) { - let op = h.get_optype(node); - match op { - OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1, - OpType::Const(c) if check_value(c.value()) => node_count += 1, - _ => panic!("unexpected op: {:?}\n{}", op, h.mermaid_string()), - } - } - - assert_eq!(node_count, 4); -} - -/// int to constant -fn i2c(b: u64) -> Value { - Value::extension(ConstInt::new_u(5, b).unwrap()) -} - -/// float to constant -fn f2c(f: f64) -> Value { - ConstF64::new(f).into() -} - -#[rstest] -#[case(0.0, 0.0, 0.0)] -#[case(0.0, 1.0, 1.0)] -#[case(23.5, 435.5, 459.0)] -// c = a + b -fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { - fn unwrap_float(pv: PartialValue) -> f64 { - pv.try_into_value::(&FLOAT64_TYPE) - .unwrap() - .get_custom_value::() - .unwrap() - .value() - } - let [n, n_a, n_b] = [0, 1, 2].map(portgraph::NodeIndex::new).map(Node::from); - let mut temp = Hugr::default(); - let ctx = ConstFoldContext(&mut temp); - let v_a = ctx.value_from_const(n_a, &f2c(a)); - let v_b = ctx.value_from_const(n_b, &f2c(b)); - assert_eq!(unwrap_float(v_a.clone()), a); - assert_eq!(unwrap_float(v_b.clone()), b); - - let mut outs = [PartialValue::Bottom]; - let OpType::ExtensionOp(add_op) = OpType::from(FloatOps::fadd) else { - panic!() - }; - ctx.interpret_leaf_op(n, &add_op, &[v_a, v_b], &mut outs); - - assert_eq!(unwrap_float(outs[0].clone()), c); -} - -fn noargfn(outputs: impl Into) -> Signature { - inout_sig(type_row![], outputs) -} - -#[test] -fn test_big() { - /* - Test approximately calculates - let x = (5.6, 3.2); - int(x.0 - x.1) == 2 - */ - let sum_type = sum_with_error(INT_TYPES[5].to_owned()); - let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); - - let tup = build.add_load_const(Value::tuple([f2c(5.6), f2c(3.2)])); - - let unpack = build - .add_dataflow_op( - UnpackTuple::new(type_row![FLOAT64_TYPE, FLOAT64_TYPE]), - [tup], - ) - .unwrap(); - - let sub = build - .add_dataflow_op(FloatOps::fsub, unpack.outputs()) - .unwrap(); - let to_int = build - .add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(5), sub.outputs()) - .unwrap(); - - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - arithmetic::float_ops::EXTENSION.to_owned(), - arithmetic::conversions::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build - .finish_hugr_with_outputs(to_int.outputs(), ®) - .unwrap(); - assert_eq!(h.node_count(), 8); - - constant_fold_pass(&mut h, ®); - - let expected = const_ok(i2c(2).clone(), ERROR_TYPE); - assert_fully_folded(&h, &expected); -} - -#[test] -#[ignore = "Waiting for `unwrap` operation"] -// TODO: https://github.com/CQCL/hugr/issues/1486 -fn test_list_ops() -> Result<(), Box> { - use hugr_core::std_extensions::collections::{self, ListOp, ListValue}; - - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - logic::EXTENSION.to_owned(), - collections::EXTENSION.to_owned(), - ]) - .unwrap(); - let base_list: Value = ListValue::new(BOOL_T, [Value::false_val()]).into(); - let mut build = DFGBuilder::new(Signature::new( - type_row![], - vec![base_list.get_type().clone()], - )) - .unwrap(); - - let list = build.add_load_const(base_list.clone()); - - let [list, maybe_elem] = build - .add_dataflow_op( - ListOp::pop.with_type(BOOL_T).to_extension_op(®).unwrap(), - [list], - )? - .outputs_arr(); - - // FIXME: Unwrap the Option - let elem = maybe_elem; - - let [list] = build - .add_dataflow_op( - ListOp::push - .with_type(BOOL_T) - .to_extension_op(®) - .unwrap(), - [list, elem], - )? - .outputs_arr(); - - let mut h = build.finish_hugr_with_outputs([list], ®)?; - - constant_fold_pass(&mut h, ®); - - assert_fully_folded(&h, &base_list); - Ok(()) -} - -#[test] -fn test_fold_and() { - // pseudocode: - // x0, x1 := bool(true), bool(true) - // x2 := and(x0, x1) - // output x2 == true; - let mut build = DFGBuilder::new(noargfn(BOOL_T)).unwrap(); - let x0 = build.add_load_const(Value::true_val()); - let x1 = build.add_load_const(Value::true_val()); - let x2 = build.add_dataflow_op(LogicOp::And, [x0, x1]).unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_or() { - // pseudocode: - // x0, x1 := bool(true), bool(false) - // x2 := or(x0, x1) - // output x2 == true; - let mut build = DFGBuilder::new(noargfn(BOOL_T)).unwrap(); - let x0 = build.add_load_const(Value::true_val()); - let x1 = build.add_load_const(Value::false_val()); - let x2 = build.add_dataflow_op(LogicOp::Or, [x0, x1]).unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_not() { - // pseudocode: - // x0 := bool(true) - // x1 := not(x0) - // output x1 == false; - let mut build = DFGBuilder::new(noargfn(BOOL_T)).unwrap(); - let x0 = build.add_load_const(Value::true_val()); - let x1 = build.add_dataflow_op(LogicOp::Not, [x0]).unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::false_val(); - assert_fully_folded(&h, &expected); -} - -#[test] -fn orphan_output() { - // pseudocode: - // x0 := bool(true) - // x1 := not(x0) - // x2 := or(x0,x1) - // output x2 == true; - // - // We arrange things so that the `or` folds away first, leaving the not - // with no outputs. - use hugr_core::ops::handle::NodeHandle; - - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); - let true_wire = build.add_load_value(Value::true_val()); - // this Not will be manually replaced - let orig_not = build.add_dataflow_op(LogicOp::Not, [true_wire]).unwrap(); - let r = build - .add_dataflow_op(LogicOp::Or, [true_wire, orig_not.out_wire(0)]) - .unwrap(); - let or_node = r.node(); - let parent = build.container_node(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(r.outputs(), ®).unwrap(); - - // we delete the original Not and create a new One. This means it will be - // traversed by `constant_fold_pass` after the Or. - let new_not = h.add_node_with_parent(parent, LogicOp::Not); - h.connect(true_wire.node(), true_wire.source(), new_not, 0); - h.disconnect(or_node, IncomingPort::from(1)); - h.connect(new_not, 0, or_node, 1); - h.remove_node(orig_not.node()); - constant_fold_pass(&mut h, ®); - assert_fully_folded(&h, &Value::true_val()) -} - -#[test] -fn test_folding_pass_issue_996() { - // pseudocode: - // - // x0 := 3.0 - // x1 := 4.0 - // x2 := fne(x0, x1); // true - // x3 := flt(x0, x1); // true - // x4 := and(x2, x3); // true - // x5 := -10.0 - // x6 := flt(x0, x5) // false - // x7 := or(x4, x6) // true - // output x7 - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstF64::new(3.0))); - let x1 = build.add_load_const(Value::extension(ConstF64::new(4.0))); - let x2 = build.add_dataflow_op(FloatOps::fne, [x0, x1]).unwrap(); - let x3 = build.add_dataflow_op(FloatOps::flt, [x0, x1]).unwrap(); - let x4 = build - .add_dataflow_op(LogicOp::And, x2.outputs().chain(x3.outputs())) - .unwrap(); - let x5 = build.add_load_const(Value::extension(ConstF64::new(-10.0))); - let x6 = build.add_dataflow_op(FloatOps::flt, [x0, x5]).unwrap(); - let x7 = build - .add_dataflow_op(LogicOp::Or, x4.outputs().chain(x6.outputs())) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - logic::EXTENSION.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x7.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_const_fold_to_nonfinite() { - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - ]) - .unwrap(); - - // HUGR computing 1.0 / 1.0 - let mut build = DFGBuilder::new(noargfn(vec![FLOAT64_TYPE])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); - let x1 = build.add_load_const(Value::extension(ConstF64::new(1.0))); - let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); - let mut h0 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h0, ®); - assert_fully_folded_with(&h0, |v| { - v.get_custom_value::().unwrap().value() == 1.0 - }); - assert_eq!(h0.node_count(), 5); - - // HUGR computing 1.0 / 0.0 - let mut build = DFGBuilder::new(noargfn(vec![FLOAT64_TYPE])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); - let x1 = build.add_load_const(Value::extension(ConstF64::new(0.0))); - let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); - let mut h1 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h1, ®); - assert_eq!(h1.node_count(), 8); -} - -#[test] -fn test_fold_iwiden_u() { - // pseudocode: - // - // x0 := int_u<4>(13); - // x1 := iwiden_u<4, 5>(x0); - // output x1 == int_u<5>(13); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(4, 13).unwrap())); - let x1 = build - .add_dataflow_op(IntOpDef::iwiden_u.with_two_log_widths(4, 5), [x0]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(5, 13).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_iwiden_s() { - // pseudocode: - // - // x0 := int_u<4>(-3); - // x1 := iwiden_u<4, 5>(x0); - // output x1 == int_s<5>(-3); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(4, -3).unwrap())); - let x1 = build - .add_dataflow_op(IntOpDef::iwiden_s.with_two_log_widths(4, 5), [x0]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_s(5, -3).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[rstest] -#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 4, -3, true)] -#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 5, -3, true)] -#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 1, -3, false)] -#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 4, 13, true)] -#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 5, 13, true)] -#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 0, 3, false)] -fn test_fold_inarrow, E: std::fmt::Debug>( - #[case] mk_const: impl Fn(u8, I) -> Result, - #[case] op_def: IntOpDef, - #[case] from_log_width: u8, - #[case] to_log_width: u8, - #[case] val: I, - #[case] succeeds: bool, -) { - // For the first case, pseudocode: - // - // x0 := int_s<5>(-3); - // x1 := inarrow_s<5, 4>(x0); - // output x1 == sum(-3)]>; - // - // Other cases vary by: - // (mk_const, op_def) => create signed or unsigned constants, create - // inarrow_s or inarrow_u ops; - // (from_log_width, to_log_width) => the args to use to create the op; - // val => the value to pass to the op - // succeeds => whether to expect a int variant or an error - // variant. - - use hugr_core::extension::prelude::const_ok; - let elem_type = INT_TYPES[to_log_width as usize].to_owned(); - let sum_type = sum_with_error(elem_type.clone()); - let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); - let x0 = build.add_load_const(mk_const(from_log_width, val).unwrap().into()); - let x1 = build - .add_dataflow_op( - op_def.with_two_log_widths(from_log_width, to_log_width), - [x0], - ) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - lazy_static! { - static ref INARROW_ERROR_VALUE: ConstError = ConstError { - signal: 0, - message: "Integer too large to narrow".to_string(), - }; - } - let expected = if succeeds { - const_ok(mk_const(to_log_width, val).unwrap().into(), ERROR_TYPE) - } else { - INARROW_ERROR_VALUE.clone().as_either(elem_type) - }; - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_itobool() { - // pseudocode: - // - // x0 := int_u<0>(1); - // x1 := itobool(x0); - // output x1 == true; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(0, 1).unwrap())); - let x1 = build - .add_dataflow_op(ConvertOpDef::itobool.without_log_width(), [x0]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_ifrombool() { - // pseudocode: - // - // x0 := false - // x1 := ifrombool(x0); - // output x1 == int_u<0>(0); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[0].clone()])).unwrap(); - let x0 = build.add_load_const(Value::false_val()); - let x1 = build - .add_dataflow_op(ConvertOpDef::ifrombool.without_log_width(), [x0]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(0, 0).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_ieq() { - // pseudocode: - // x0, x1 := int_s<3>(-1), int_u<3>(255) - // x2 := ieq(x0, x1) - // output x2 == true; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(3, -1).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 255).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::ieq.with_log_width(3), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_ine() { - // pseudocode: - // x0, x1 := int_u<5>(3), int_u<5>(4) - // x2 := ine(x0, x1) - // output x2 == true; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::ine.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_ilt_u() { - // pseudocode: - // x0, x1 := int_u<5>(3), int_u<5>(4) - // x2 := ilt_u(x0, x1) - // output x2 == true; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::ilt_u.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_ilt_s() { - // pseudocode: - // x0, x1 := int_s<5>(3), int_s<5>(-4) - // x2 := ilt_s(x0, x1) - // output x2 == false; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 3).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::ilt_s.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::false_val(); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_igt_u() { - // pseudocode: - // x0, x1 := int_u<5>(3), int_u<5>(4) - // x2 := ilt_u(x0, x1) - // output x2 == false; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::igt_u.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::false_val(); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_igt_s() { - // pseudocode: - // x0, x1 := int_s<5>(3), int_s<5>(-4) - // x2 := ilt_s(x0, x1) - // output x2 == true; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 3).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::igt_s.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_ile_u() { - // pseudocode: - // x0, x1 := int_u<5>(3), int_u<5>(3) - // x2 := ile_u(x0, x1) - // output x2 == true; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::ile_u.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_ile_s() { - // pseudocode: - // x0, x1 := int_s<5>(-4), int_s<5>(-4) - // x2 := ile_s(x0, x1) - // output x2 == true; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::ile_s.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_ige_u() { - // pseudocode: - // x0, x1 := int_u<5>(3), int_u<5>(4) - // x2 := ilt_u(x0, x1) - // output x2 == false; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::ige_u.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::false_val(); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_ige_s() { - // pseudocode: - // x0, x1 := int_s<5>(3), int_s<5>(-4) - // x2 := ilt_s(x0, x1) - // output x2 == true; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 3).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::ige_s.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_imax_u() { - // pseudocode: - // x0, x1 := int_u<5>(7), int_u<5>(11); - // x2 := imax_u(x0, x1); - // output x2 == int_u<5>(11); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 7).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 11).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::imax_u.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(5, 11).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_imax_s() { - // pseudocode: - // x0, x1 := int_s<5>(-2), int_s<5>(1); - // x2 := imax_u(x0, x1); - // output x2 == int_s<5>(1); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 1).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::imax_s.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_s(5, 1).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_imin_u() { - // pseudocode: - // x0, x1 := int_u<5>(7), int_u<5>(11); - // x2 := imin_u(x0, x1); - // output x2 == int_u<5>(7); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 7).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 11).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::imin_u.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(5, 7).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_imin_s() { - // pseudocode: - // x0, x1 := int_s<5>(-2), int_s<5>(1); - // x2 := imin_u(x0, x1); - // output x2 == int_s<5>(-2); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 1).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::imin_s.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_s(5, -2).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_iadd() { - // pseudocode: - // x0, x1 := int_s<5>(-2), int_s<5>(1); - // x2 := iadd(x0, x1); - // output x2 == int_s<5>(-1); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 1).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::iadd.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_s(5, -1).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_isub() { - // pseudocode: - // x0, x1 := int_s<5>(-2), int_s<5>(1); - // x2 := isub(x0, x1); - // output x2 == int_s<5>(-3); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 1).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::isub.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_s(5, -3).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_ineg() { - // pseudocode: - // x0 := int_s<5>(-2); - // x1 := ineg(x0); - // output x1 == int_s<5>(2); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::ineg.with_log_width(5), [x0]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_s(5, 2).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_imul() { - // pseudocode: - // x0, x1 := int_s<5>(-2), int_s<5>(7); - // x2 := imul(x0, x1); - // output x2 == int_s<5>(-14); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 7).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::imul.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_s(5, -14).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_idivmod_checked_u() { - // pseudocode: - // x0, x1 := int_u<5>(20), int_u<5>(0) - // x2 := idivmod_checked_u(x0, x1) - // output x2 == error - let intpair: TypeRowRV = vec![INT_TYPES[5].clone(), INT_TYPES[5].clone()].into(); - let elem_type = Type::new_tuple(intpair); - let sum_type = sum_with_error(elem_type.clone()); - let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 0).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::idivmod_checked_u.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = ConstError { - signal: 0, - message: "Division by zero".to_string(), - } - .as_either(elem_type); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_idivmod_u() { - // pseudocode: - // x0, x1 := int_u<3>(20), int_u<3>(3); - // x2, x3 := idivmod_u(x0, x1); // 6, 2 - // x4 := iadd<3>(x2, x3); // 8 - // output x4 == int_u<5>(8); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[3].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(3, 20).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap())); - let [x2, x3] = build - .add_dataflow_op(IntOpDef::idivmod_u.with_log_width(3), [x0, x1]) - .unwrap() - .outputs_arr(); - let x4 = build - .add_dataflow_op(IntOpDef::iadd.with_log_width(3), [x2, x3]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x4.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(3, 8).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_idivmod_checked_s() { - // pseudocode: - // x0, x1 := int_s<5>(-20), int_u<5>(0) - // x2 := idivmod_checked_s(x0, x1) - // output x2 == error - let intpair: TypeRowRV = vec![INT_TYPES[5].clone(), INT_TYPES[5].clone()].into(); - let elem_type = Type::new_tuple(intpair); - let sum_type = sum_with_error(elem_type.clone()); - let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 0).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::idivmod_checked_s.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = ConstError { - signal: 0, - message: "Division by zero".to_string(), - } - .as_either(elem_type); - assert_fully_folded(&h, &expected); -} - -#[rstest] -#[case(20, 3, 8)] -#[case(-20, 3, -6)] -#[case(-20, 4, -5)] -#[case(i64::MIN, 1, i64::MIN)] -#[case(i64::MIN, 2, -(1i64 << 62))] -#[case(i64::MIN, 1u64 << 63, -1)] -// c = a/b + a%b -fn test_fold_idivmod_s(#[case] a: i64, #[case] b: u64, #[case] c: i64) { - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[6].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(6, a).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(6, b).unwrap())); - let [x2, x3] = build - .add_dataflow_op(IntOpDef::idivmod_s.with_log_width(6), [x0, x1]) - .unwrap() - .outputs_arr(); - let x4 = build - .add_dataflow_op(IntOpDef::iadd.with_log_width(6), [x2, x3]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x4.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_s(6, c).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_idiv_checked_u() { - // pseudocode: - // x0, x1 := int_u<5>(20), int_u<5>(0) - // x2 := idiv_checked_u(x0, x1) - // output x2 == error - let sum_type = sum_with_error(INT_TYPES[5].to_owned()); - let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 0).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::idiv_checked_u.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = ConstError { - signal: 0, - message: "Division by zero".to_string(), - } - .as_either(INT_TYPES[5].to_owned()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_idiv_u() { - // pseudocode: - // x0, x1 := int_u<5>(20), int_u<5>(3); - // x2 := idiv_u(x0, x1); - // output x2 == int_u<5>(6); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::idiv_u.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(5, 6).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_imod_checked_u() { - // pseudocode: - // x0, x1 := int_u<5>(20), int_u<5>(0) - // x2 := imod_checked_u(x0, x1) - // output x2 == error - let sum_type = sum_with_error(INT_TYPES[5].to_owned()); - let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 0).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::imod_checked_u.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = ConstError { - signal: 0, - message: "Division by zero".to_string(), - } - .as_either(INT_TYPES[5].to_owned()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_imod_u() { - // pseudocode: - // x0, x1 := int_u<5>(20), int_u<5>(3); - // x2 := imod_u(x0, x1); - // output x2 == int_u<3>(2); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::imod_u.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(5, 2).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_idiv_checked_s() { - // pseudocode: - // x0, x1 := int_s<5>(-20), int_u<5>(0) - // x2 := idiv_checked_s(x0, x1) - // output x2 == error - let sum_type = sum_with_error(INT_TYPES[5].to_owned()); - let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 0).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::idiv_checked_s.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = ConstError { - signal: 0, - message: "Division by zero".to_string(), - } - .as_either(INT_TYPES[5].to_owned()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_idiv_s() { - // pseudocode: - // x0, x1 := int_s<5>(-20), int_u<5>(3); - // x2 := idiv_s(x0, x1); - // output x2 == int_s<5>(-7); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::idiv_s.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_s(5, -7).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_imod_checked_s() { - // pseudocode: - // x0, x1 := int_s<5>(-20), int_u<5>(0) - // x2 := imod_checked_u(x0, x1) - // output x2 == error - let sum_type = sum_with_error(INT_TYPES[5].to_owned()); - let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 0).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::imod_checked_s.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = ConstError { - signal: 0, - message: "Division by zero".to_string(), - } - .as_either(INT_TYPES[5].to_owned()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_imod_s() { - // pseudocode: - // x0, x1 := int_s<5>(-20), int_u<5>(3); - // x2 := imod_s(x0, x1); - // output x2 == int_u<5>(1); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::imod_s.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(5, 1).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_iabs() { - // pseudocode: - // x0 := int_s<5>(-2); - // x1 := iabs(x0); - // output x1 == int_s<5>(2); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::iabs.with_log_width(5), [x0]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(5, 2).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_iand() { - // pseudocode: - // x0, x1 := int_u<5>(14), int_u<5>(20); - // x2 := iand(x0, x1); - // output x2 == int_u<5>(4); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 14).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::iand.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(5, 4).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_ior() { - // pseudocode: - // x0, x1 := int_u<5>(14), int_u<5>(20); - // x2 := ior(x0, x1); - // output x2 == int_u<5>(30); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 14).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::ior.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(5, 30).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_ixor() { - // pseudocode: - // x0, x1 := int_u<5>(14), int_u<5>(20); - // x2 := ixor(x0, x1); - // output x2 == int_u<5>(26); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 14).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::ixor.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(5, 26).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_inot() { - // pseudocode: - // x0 := int_u<5>(14); - // x1 := inot(x0); - // output x1 == int_u<5>(17); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 14).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::inot.with_log_width(5), [x0]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(5, (1u64 << 32) - 15).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_ishl() { - // pseudocode: - // x0, x1 := int_u<5>(14), int_u<3>(3); - // x2 := ishl(x0, x1); - // output x2 == int_u<5>(112); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::ishl.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(5, 112).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_ishr() { - // pseudocode: - // x0, x1 := int_u<5>(14), int_u<3>(3); - // x2 := ishr(x0, x1); - // output x2 == int_u<5>(1); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::ishr.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(5, 1).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_irotl() { - // pseudocode: - // x0, x1 := int_u<5>(14), int_u<3>(61); - // x2 := irotl(x0, x1); - // output x2 == int_u<5>(2^30 + 2^31 + 1); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 61).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::irotl.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(5, 3 * (1u64 << 30) + 1).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_irotr() { - // pseudocode: - // x0, x1 := int_u<5>(14), int_u<3>(3); - // x2 := irotr(x0, x1); - // output x2 == int_u<5>(2^30 + 2^31 + 1); - let mut build = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::irotr.with_log_width(5), [x0, x1]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(5, 3 * (1u64 << 30) + 1).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_itostring_u() { - // pseudocode: - // x0 := int_u<5>(17); - // x1 := itostring_u(x0); - // output x2 := "17"; - let mut build = DFGBuilder::new(noargfn(vec![STRING_TYPE])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 17).unwrap())); - let x1 = build - .add_dataflow_op(ConvertOpDef::itostring_u.with_log_width(5), [x0]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstString::new("17".into())); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_itostring_s() { - // pseudocode: - // x0 := int_s<5>(-17); - // x1 := itostring_s(x0); - // output x2 := "-17"; - let mut build = DFGBuilder::new(noargfn(vec![STRING_TYPE])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -17).unwrap())); - let x1 = build - .add_dataflow_op(ConvertOpDef::itostring_s.with_log_width(5), [x0]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstString::new("-17".into())); - assert_fully_folded(&h, &expected); -} - -#[test] -fn test_fold_int_ops() { - // pseudocode: - // - // x0 := int_u<5>(3); // 3 - // x1 := int_u<5>(4); // 4 - // x2 := ine(x0, x1); // true - // x3 := ilt_u(x0, x1); // true - // x4 := and(x2, x3); // true - // x5 := int_s<5>(-10) // -10 - // x6 := ilt_s(x0, x5) // false - // x7 := or(x4, x6) // true - // output x7 - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); - let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); - let x2 = build - .add_dataflow_op(IntOpDef::ine.with_log_width(5), [x0, x1]) - .unwrap(); - let x3 = build - .add_dataflow_op(IntOpDef::ilt_u.with_log_width(5), [x0, x1]) - .unwrap(); - let x4 = build - .add_dataflow_op(LogicOp::And, x2.outputs().chain(x3.outputs())) - .unwrap(); - let x5 = build.add_load_const(Value::extension(ConstInt::new_s(5, -10).unwrap())); - let x6 = build - .add_dataflow_op(IntOpDef::ilt_s.with_log_width(5), [x0, x5]) - .unwrap(); - let x7 = build - .add_dataflow_op(LogicOp::Or, x4.outputs().chain(x6.outputs())) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - logic::EXTENSION.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x7.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); -} diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 0b73fcbb0..06781f7c5 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,7 +1,6 @@ //! Compilation passes acting on the HUGR program representation. pub mod const_fold; -pub mod const_fold2; pub mod dataflow; pub mod force_order; mod half_node; From 55f9700c7463f4d5de92ad177127e24f656ce297 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 28 Oct 2024 17:48:23 +0000 Subject: [PATCH 183/281] Revert accidental changes to views.rs --- hugr-core/src/hugr/views.rs | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index fc90df10d..6a52f33f0 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -528,6 +528,14 @@ impl RootTagged for Hugr { type RootHandle = Node; } +impl RootTagged for &Hugr { + type RootHandle = Node; +} + +impl RootTagged for &mut Hugr { + type RootHandle = Node; +} + // Explicit implementation to avoid cloning the Hugr. impl ExtractHugr for Hugr { fn extract_hugr(self) -> Hugr { @@ -547,20 +555,6 @@ impl ExtractHugr for &mut Hugr { } } -impl<'a, H: RootTagged> RootTagged for &'a H -where - &'a H: HugrView, -{ - type RootHandle = H::RootHandle; -} - -impl<'a, H: RootTagged> RootTagged for &'a mut H -where - &'a mut H: HugrView, -{ - type RootHandle = H::RootHandle; -} - impl> HugrView for T { /// An Iterator over the nodes in a Hugr(View) type Nodes<'a> = MapInto, Node> where Self: 'a; From 494d8490156738e237710f2c1c32ca1ad3055eea Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 29 Oct 2024 12:01:40 +0000 Subject: [PATCH 184/281] Hide AnalysisResults::hugr via accessor,destructor,two-step transform --- hugr-passes/src/const_fold.rs | 62 ++++++++++++++++------------- hugr-passes/src/dataflow/datalog.rs | 2 +- hugr-passes/src/dataflow/results.rs | 12 +++++- 3 files changed, 46 insertions(+), 30 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 240e4f73a..5db190a7f 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -43,40 +43,46 @@ impl ConstFoldPass { fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { let results = Machine::default().run(ConstFoldContext(hugr), []); let mut keep_nodes = HashSet::new(); - self.find_needed_nodes(&results, results.hugr.root(), &mut keep_nodes); + self.find_needed_nodes(&results, results.hugr().root(), &mut keep_nodes); let remove_nodes = results - .hugr + .hugr() .nodes() .filter(|n| !keep_nodes.contains(n)) .collect::>(); - for n in keep_nodes { - // Every input either (a) is in keep_nodes, or (b) has a known value. Break all wires (b). - for inport in results.hugr.node_inputs(n) { - if matches!( - results.hugr.get_optype(n).port_kind(inport).unwrap(), + let wires_to_break = keep_nodes + .into_iter() + .flat_map(|n| results.hugr().node_inputs(n).map(move |ip| (n, ip))) + .filter(|(n, ip)| { + matches!( + results.hugr().get_optype(*n).port_kind(*ip).unwrap(), EdgeKind::Value(_) - ) { - let (src, outp) = results.hugr.single_linked_output(n, inport).unwrap(); - if results.hugr.get_optype(src).is_load_constant() { - continue; - } - if let Ok(v) = results.try_read_wire_value(Wire::new(src, outp)) { - let parent = results.hugr.get_parent(n).unwrap(); - let datatype = v.get_type(); - // We could try hash-consing identical Consts, but not ATM - let hugr_mut = &mut *results.hugr.0; - let cst = hugr_mut.add_node_with_parent(parent, Const::new(v)); - let lcst = hugr_mut.add_node_with_parent(parent, LoadConstant { datatype }); - hugr_mut.connect(cst, OutgoingPort::from(0), lcst, IncomingPort::from(0)); - hugr_mut.disconnect(n, inport); - hugr_mut.connect(lcst, OutgoingPort::from(0), n, inport); - } - } - } + ) + }) + // Note we COULD filter out (avoid breaking) wires from other nodes that we are keeping. + // This would insert fewer constants, but potentially expose less parallelism. + .filter_map(|(n, ip)| { + let (src, outp) = results.hugr().single_linked_output(n, ip).unwrap(); + (!results.hugr().get_optype(src).is_load_constant()).then_some(( + n, + ip, + results.try_read_wire_value(Wire::new(src, outp)).ok()?, + )) + }) + .collect::>(); + let hugr_mut = results.into_hugr().0; // and drop 'results' + for (n, inport, v) in wires_to_break { + let parent = hugr_mut.get_parent(n).unwrap(); + let datatype = v.get_type(); + // We could try hash-consing identical Consts, but not ATM + let cst = hugr_mut.add_node_with_parent(parent, Const::new(v)); + let lcst = hugr_mut.add_node_with_parent(parent, LoadConstant { datatype }); + hugr_mut.connect(cst, OutgoingPort::from(0), lcst, IncomingPort::from(0)); + hugr_mut.disconnect(n, inport); + hugr_mut.connect(lcst, OutgoingPort::from(0), n, inport); } for n in remove_nodes { - hugr.remove_node(n); + hugr_mut.remove_node(n); } Ok(()) } @@ -98,7 +104,7 @@ impl ConstFoldPass { ) { let mut q = VecDeque::new(); q.push_back(root); - let h = &results.hugr; + let h = results.hugr(); while let Some(n) = q.pop_front() { if !needed.insert(n) { continue; @@ -132,7 +138,7 @@ impl ConstFoldPass { for (src, op) in h.all_linked_outputs(n) { let needs_predecessor = match h.get_optype(src).port_kind(op).unwrap() { EdgeKind::Value(_) => { - results.hugr.get_optype(src).is_load_constant() + results.hugr().get_optype(src).is_load_constant() || results.try_read_wire_value(Wire::new(src, op)).is_err() } EdgeKind::StateOrder | EdgeKind::Const(_) | EdgeKind::Function(_) => true, diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 2cc535d89..455b328e5 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -308,7 +308,7 @@ fn propagate_leaf_op( } OpType::ExtensionOp(e) => { // Interpret op. - let init = if ins.iter().contains(&PartialValue::Bottom) { + let init = if ins.iter().contains(&PartialValue::Bottom) { // So far we think one or more inputs can't happen. // So, don't pollute outputs with Top, and wait for better knowledge of inputs. PartialValue::Bottom diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index a8800f755..5cddc5354 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -7,7 +7,7 @@ use super::{AbstractValue, PartialValue}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). pub struct AnalysisResults { - pub hugr: H, + pub(super) hugr: H, pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, pub(super) case_reachable: Vec<(Node, Node)>, pub(super) bb_reachable: Vec<(Node, Node)>, @@ -15,6 +15,16 @@ pub struct AnalysisResults { } impl AnalysisResults { + /// Allows to use the [HugrView] contained within + pub fn hugr(&self) -> &H { + &self.hugr + } + + /// Discards the results, allowing to get back the [HugrView] within + pub fn into_hugr(self) -> H { + self.hugr + } + /// Gets the lattice value computed for the given wire pub fn read_out_wire(&self, w: Wire) -> Option> { self.out_wire_values.get(&w).cloned() From 25f2bb74ab5c8f3aa66b5df9048fa6e959af89e0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 29 Oct 2024 12:23:58 +0000 Subject: [PATCH 185/281] Separate out DFContext::hugr(&self) -> &impl HugrView --- hugr-passes/src/const_fold.rs | 18 +++--- hugr-passes/src/dataflow.rs | 4 +- hugr-passes/src/dataflow/datalog.rs | 75 +++++++++++++---------- hugr-passes/src/dataflow/results.rs | 38 ++++++------ hugr-passes/src/dataflow/test.rs | 9 ++- hugr-passes/src/dataflow/total_context.rs | 8 ++- 6 files changed, 82 insertions(+), 70 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 5db190a7f..6e9cbb37a 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -70,7 +70,7 @@ impl ConstFoldPass { )) }) .collect::>(); - let hugr_mut = results.into_hugr().0; // and drop 'results' + let hugr_mut = results.into_ctx().0; // and drop 'results' for (n, inport, v) in wires_to_break { let parent = hugr_mut.get_parent(n).unwrap(); let datatype = v.get_type(); @@ -96,9 +96,9 @@ impl ConstFoldPass { .run_validated_pass(hugr, reg, |hugr: &mut H, _| self.run_no_validate(hugr)) } - fn find_needed_nodes( + fn find_needed_nodes( &self, - results: &AnalysisResults, + results: &AnalysisResults>, root: Node, needed: &mut HashSet, ) { @@ -160,12 +160,6 @@ pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { struct ConstFoldContext<'a, H>(&'a mut H); -impl<'a, T: HugrView> AsRef for ConstFoldContext<'a, T> { - fn as_ref(&self) -> &hugr_core::Hugr { - self.0.base_hugr() - } -} - impl<'a, H: HugrView> ConstLoader for ConstFoldContext<'a, H> { fn value_from_opaque( &self, @@ -196,7 +190,7 @@ impl<'a, H: HugrView> ConstLoader for ConstFoldContext<'a, H> { }; // Returning the function body as a value, here, would be sufficient for inlining IndirectCall // but not for transforming to a direct Call. - let func = DescendantsGraph::>::try_new(self, node).ok()?; + let func = DescendantsGraph::>::try_new(self.hugr(), node).ok()?; Some(ValueHandle::new_const_hugr( node, &[], @@ -206,6 +200,10 @@ impl<'a, H: HugrView> ConstLoader for ConstFoldContext<'a, H> { } impl<'a, H: HugrView> TotalContext for ConstFoldContext<'a, H> { + fn hugr(&self) -> &impl HugrView { + self.0 + } + type InterpretableVal = Value; fn interpret_leaf_op( diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 650912b0f..df6342df2 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -18,7 +18,9 @@ use hugr_core::{Hugr, HugrView, Node}; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). -pub trait DFContext: ConstLoader + HugrView { +pub trait DFContext: ConstLoader { + fn hugr(&self) -> &impl HugrView; + /// Given lattice values for each input, update lattice values for the (dataflow) outputs. /// For extension ops only, excluding [MakeTuple] and [UnpackTuple]. /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 455b328e5..f6184e9f1 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -45,7 +45,7 @@ impl Machine { context: C, in_values: impl IntoIterator)>, ) -> AnalysisResults { - let root = context.root(); + let root = context.hugr().root(); self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); run_datalog(context, self.0) @@ -75,38 +75,38 @@ pub(super) fn run_datalog>( lattice in_wire_value(Node, IncomingPort, PV); // receives, on , the value lattice node_in_value_row(Node, ValueRow); // 's inputs are - node(n) <-- for n in ctx.nodes(); + node(n) <-- for n in ctx.hugr().nodes(); - in_wire(n, p) <-- node(n), for (p,_) in ctx.in_value_types(*n); // Note, gets connected inports only - out_wire(n, p) <-- node(n), for (p,_) in ctx.out_value_types(*n); // (and likewise) + in_wire(n, p) <-- node(n), for (p,_) in ctx.hugr().in_value_types(*n); // Note, gets connected inports only + out_wire(n, p) <-- node(n), for (p,_) in ctx.hugr().out_value_types(*n); // (and likewise) parent_of_node(parent, child) <-- - node(child), if let Some(parent) = ctx.get_parent(*child); + node(child), if let Some(parent) = ctx.hugr().get_parent(*child); - input_child(parent, input) <-- node(parent), if let Some([input, _output]) = ctx.get_io(*parent); - output_child(parent, output) <-- node(parent), if let Some([_input, output]) = ctx.get_io(*parent); + input_child(parent, input) <-- node(parent), if let Some([input, _output]) = ctx.hugr().get_io(*parent); + output_child(parent, output) <-- node(parent), if let Some([_input, output]) = ctx.hugr().get_io(*parent); // Initialize all wires to bottom out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p); in_wire_value(n, ip, v) <-- in_wire(n, ip), - if let Some((m, op)) = ctx.single_linked_output(*n, *ip), + if let Some((m, op)) = ctx.hugr().single_linked_output(*n, *ip), out_wire_value(m, op, v); // Prepopulate in_wire_value from in_wire_value_proto. in_wire_value(n, p, PV::bottom()) <-- in_wire(n, p); in_wire_value(n, p, v) <-- for (n, p, v) in in_wire_value_proto.iter(), node(n), - if let Some(sig) = ctx.signature(*n), + if let Some(sig) = ctx.hugr().signature(*n), if sig.input_ports().contains(p); // Assemble in_value_row from in_value's - node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = ctx.signature(*n); - node_in_value_row(n, ValueRow::single_known(ctx.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); + node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = ctx.hugr().signature(*n); + node_in_value_row(n, ValueRow::single_known(ctx.hugr().signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); out_wire_value(n, p, v) <-- node(n), - let op_t = ctx.get_optype(*n), + let op_t = ctx.hugr().get_optype(*n), if !op_t.is_container(), if let Some(sig) = op_t.dataflow_signature(), node_in_value_row(n, vs), @@ -115,7 +115,7 @@ pub(super) fn run_datalog>( // DFG relation dfg_node(Node); // is a `DFG` - dfg_node(n) <-- node(n), if ctx.get_optype(*n).is_dfg(); + dfg_node(n) <-- node(n), if ctx.hugr().get_optype(*n).is_dfg(); out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), input_child(dfg, i), in_wire_value(dfg, p, v); @@ -128,13 +128,13 @@ pub(super) fn run_datalog>( // inputs of tail loop propagate to Input node of child region out_wire_value(i, OutgoingPort::from(p.index()), v) <-- node(tl), - if ctx.get_optype(*tl).is_tail_loop(), + if ctx.hugr().get_optype(*tl).is_tail_loop(), input_child(tl, i), in_wire_value(tl, p, v); // Output node of child region propagate to Input node of child region out_wire_value(in_n, OutgoingPort::from(out_p), v) <-- node(tl), - if let Some(tailloop) = ctx.get_optype(*tl).as_tail_loop(), + if let Some(tailloop) = ctx.hugr().get_optype(*tl).as_tail_loop(), input_child(tl, in_n), output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node @@ -143,7 +143,7 @@ pub(super) fn run_datalog>( // Output node of child region propagate to outputs of tail loop out_wire_value(tl, OutgoingPort::from(out_p), v) <-- node(tl), - if let Some(tailloop) = ctx.get_optype(*tl).as_tail_loop(), + if let Some(tailloop) = ctx.hugr().get_optype(*tl).as_tail_loop(), output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 @@ -154,17 +154,17 @@ pub(super) fn run_datalog>( // is a `Conditional` and its 'th child (a `Case`) is : relation case_node(Node, usize, Node); - conditional_node(n)<-- node(n), if ctx.get_optype(*n).is_conditional(); + conditional_node(n)<-- node(n), if ctx.hugr().get_optype(*n).is_conditional(); case_node(cond, i, case) <-- conditional_node(cond), - for (i, case) in ctx.children(*cond).enumerate(), - if ctx.get_optype(case).is_case(); + for (i, case) in ctx.hugr().children(*cond).enumerate(), + if ctx.hugr().get_optype(case).is_case(); // inputs of conditional propagate into case nodes out_wire_value(i_node, OutgoingPort::from(out_p), v) <-- case_node(cond, case_index, case), input_child(case, i_node), node_in_value_row(cond, in_row), - let conditional = ctx.get_optype(*cond).as_conditional().unwrap(), + let conditional = ctx.hugr().get_optype(*cond).as_conditional().unwrap(), if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), for (out_p, v) in fields.enumerate(); @@ -183,39 +183,39 @@ pub(super) fn run_datalog>( // CFG relation cfg_node(Node); // is a `CFG` - cfg_node(n) <-- node(n), if ctx.get_optype(*n).is_cfg(); + cfg_node(n) <-- node(n), if ctx.hugr().get_optype(*n).is_cfg(); // In `CFG` , basic block is reachable given our knowledge of predicates relation bb_reachable(Node, Node); - bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = ctx.children(*cfg).next(); + bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = ctx.hugr().children(*cfg).next(); bb_reachable(cfg, bb) <-- cfg_node(cfg), bb_reachable(cfg, pred), output_child(pred, pred_out), in_wire_value(pred_out, IncomingPort::from(0), predicate), - for (tag, bb) in ctx.output_neighbours(*pred).enumerate(), + for (tag, bb) in ctx.hugr().output_neighbours(*pred).enumerate(), if predicate.supports_tag(tag); // Inputs of CFG propagate to entry block out_wire_value(i_node, OutgoingPort::from(p.index()), v) <-- cfg_node(cfg), - if let Some(entry) = ctx.children(*cfg).next(), + if let Some(entry) = ctx.hugr().children(*cfg).next(), input_child(entry, i_node), in_wire_value(cfg, p, v); // In `CFG` , values fed along a control-flow edge to // come out of Value outports of . relation _cfg_succ_dest(Node, Node, Node); - _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = ctx.children(*cfg).nth(1); + _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = ctx.hugr().children(*cfg).nth(1); _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), - for blk in ctx.children(*cfg), - if ctx.get_optype(blk).is_dataflow_block(), + for blk in ctx.hugr().children(*cfg), + if ctx.hugr().get_optype(blk).is_dataflow_block(), input_child(blk, inp); // Outputs of each reachable block propagated to successor block or CFG itself out_wire_value(dest, OutgoingPort::from(out_p), v) <-- bb_reachable(cfg, pred), - if let Some(df_block) = ctx.get_optype(*pred).as_dataflow_block(), - for (succ_n, succ) in ctx.output_neighbours(*pred).enumerate(), + if let Some(df_block) = ctx.hugr().get_optype(*pred).as_dataflow_block(), + for (succ_n, succ) in ctx.hugr().output_neighbours(*pred).enumerate(), output_child(pred, out_n), _cfg_succ_dest(cfg, succ, dest), node_in_value_row(out_n, out_in_row), @@ -226,8 +226,8 @@ pub(super) fn run_datalog>( relation func_call(Node, Node); // is a `Call` to `FuncDefn` func_call(call, func_defn) <-- node(call), - if ctx.get_optype(*call).is_call(), - if let Some(func_defn) = ctx.static_source(*call); + if ctx.hugr().get_optype(*call).is_call(), + if let Some(func_defn) = ctx.hugr().static_source(*call); out_wire_value(inp, OutgoingPort::from(p.index()), v) <-- func_call(call, func), @@ -245,7 +245,7 @@ pub(super) fn run_datalog>( .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) .collect(); AnalysisResults { - hugr: ctx, + ctx, out_wire_values, in_wire_value: all_results.in_wire_value, case_reachable: all_results.case_reachable, @@ -259,7 +259,7 @@ fn propagate_leaf_op( ins: &[PV], num_outs: usize, ) -> Option> { - match ctx.get_optype(n) { + match ctx.hugr().get_optype(n) { // Handle basics here. We could instead leave these to DFContext, // but at least we'd want these impls to be easily reusable. op if op.cast::().is_some() => Some(ValueRow::from_iter([PV::new_variant( @@ -282,10 +282,16 @@ fn propagate_leaf_op( OpType::LoadConstant(load_op) => { assert!(ins.is_empty()); // static edge, so need to find constant let const_node = ctx + .hugr() .single_linked_output(n, load_op.constant_port()) .unwrap() .0; - let const_val = ctx.get_optype(const_node).as_const().unwrap().value(); + let const_val = ctx + .hugr() + .get_optype(const_node) + .as_const() + .unwrap() + .value(); Some(ValueRow::single_known( 1, 0, @@ -295,6 +301,7 @@ fn propagate_leaf_op( OpType::LoadFunction(load_op) => { assert!(ins.is_empty()); // static edge let func_node = ctx + .hugr() .single_linked_output(n, load_op.function_port()) .unwrap() .0; diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index 5cddc5354..d35b1e8fd 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -2,27 +2,27 @@ use std::collections::HashMap; use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; -use super::{AbstractValue, PartialValue}; +use super::{AbstractValue, DFContext, PartialValue}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). -pub struct AnalysisResults { - pub(super) hugr: H, +pub struct AnalysisResults> { + pub(super) ctx: C, pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, pub(super) case_reachable: Vec<(Node, Node)>, pub(super) bb_reachable: Vec<(Node, Node)>, pub(super) out_wire_values: HashMap>, } -impl AnalysisResults { +impl> AnalysisResults { /// Allows to use the [HugrView] contained within - pub fn hugr(&self) -> &H { - &self.hugr + pub fn hugr(&self) -> &(impl HugrView + '_) { + self.ctx.hugr() } - /// Discards the results, allowing to get back the [HugrView] within - pub fn into_hugr(self) -> H { - self.hugr + /// Discards the results, allowing to get back the [DFContext] + pub fn into_ctx(self) -> C { + self.ctx } /// Gets the lattice value computed for the given wire @@ -36,8 +36,8 @@ impl AnalysisResults { /// /// [TailLoop]: hugr_core::ops::TailLoop pub fn tail_loop_terminates(&self, node: Node) -> Option { - self.hugr.get_optype(node).as_tail_loop()?; - let [_, out] = self.hugr.get_io(node).unwrap(); + self.hugr().get_optype(node).as_tail_loop()?; + let [_, out] = self.hugr().get_io(node).unwrap(); Some(TailLoopTermination::from_control_value( self.in_wire_value .iter() @@ -54,9 +54,9 @@ impl AnalysisResults { /// [Case]: hugr_core::ops::Case /// [Conditional]: hugr_core::ops::Conditional pub fn case_reachable(&self, case: Node) -> Option { - self.hugr.get_optype(case).as_case()?; - let cond = self.hugr.get_parent(case)?; - self.hugr.get_optype(cond).as_conditional()?; + self.hugr().get_optype(case).as_case()?; + let cond = self.hugr().get_parent(case)?; + self.hugr().get_optype(cond).as_conditional()?; Some( self.case_reachable .iter() @@ -71,9 +71,9 @@ impl AnalysisResults { /// [DataflowBlock]: hugr_core::ops::DataflowBlock /// [ExitBlock]: hugr_core::ops::ExitBlock pub fn bb_reachable(&self, bb: Node) -> Option { - let cfg = self.hugr.get_parent(bb)?; // Not really required...?? - self.hugr.get_optype(cfg).as_cfg()?; - let t = self.hugr.get_optype(bb); + let cfg = self.hugr().get_parent(bb)?; // Not really required...?? + self.hugr().get_optype(cfg).as_cfg()?; + let t = self.hugr().get_optype(bb); if !t.is_dataflow_block() && !t.is_exit_block() { return None; }; @@ -85,7 +85,7 @@ impl AnalysisResults { } } -impl AnalysisResults +impl> AnalysisResults where Value: From, { @@ -103,7 +103,7 @@ where pub fn try_read_wire_value(&self, w: Wire) -> Result> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self - .hugr + .hugr() .out_value_types(w.node()) .find(|(p, _)| *p == w.source()) .unwrap(); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index e73cc8e4f..700ec591e 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -25,13 +25,12 @@ impl AbstractValue for Void {} struct TestContext(H); -impl AsRef for TestContext { - fn as_ref(&self) -> &Hugr { - self.0.base_hugr() +impl ConstLoader for TestContext {} +impl DFContext for TestContext { + fn hugr(&self) -> &impl HugrView { + &self.0 } } -impl ConstLoader for TestContext {} -impl DFContext for TestContext {} // This allows testing creation of tuple/sum Values (only) impl From for Value { diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index f325b046a..94d7d0795 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -8,6 +8,8 @@ use super::{ConstLoader, DFContext}; /// values that are completely known (in the lattice `V`) rather than partially /// (e.g. no [PartialSum]s of more than one variant, no top/bottom) pub trait TotalContext: ConstLoader { + fn hugr(&self) -> &impl HugrView; + /// Representation of a (single, non-partial) value usable for interpretation type InterpretableVal: From + TryFrom>; @@ -21,7 +23,11 @@ pub trait TotalContext: ConstLoader { ) -> Vec<(OutgoingPort, PartialValue)>; } -impl + ConstLoader + HugrView> DFContext for T { +impl + ConstLoader> DFContext for T { + fn hugr(&self) -> &impl HugrView { + TotalContext::hugr(self) + } + fn interpret_leaf_op( &self, node: Node, From 522d1597605977fca220d0bd72d88ccf9c2ce17b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 29 Oct 2024 12:24:01 +0000 Subject: [PATCH 186/281] Revert "Separate out DFContext::hugr(&self) -> &impl HugrView" This reverts commit 25f2bb74ab5c8f3aa66b5df9048fa6e959af89e0. --- hugr-passes/src/const_fold.rs | 18 +++--- hugr-passes/src/dataflow.rs | 4 +- hugr-passes/src/dataflow/datalog.rs | 75 ++++++++++------------- hugr-passes/src/dataflow/results.rs | 38 ++++++------ hugr-passes/src/dataflow/test.rs | 9 +-- hugr-passes/src/dataflow/total_context.rs | 8 +-- 6 files changed, 70 insertions(+), 82 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 6e9cbb37a..5db190a7f 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -70,7 +70,7 @@ impl ConstFoldPass { )) }) .collect::>(); - let hugr_mut = results.into_ctx().0; // and drop 'results' + let hugr_mut = results.into_hugr().0; // and drop 'results' for (n, inport, v) in wires_to_break { let parent = hugr_mut.get_parent(n).unwrap(); let datatype = v.get_type(); @@ -96,9 +96,9 @@ impl ConstFoldPass { .run_validated_pass(hugr, reg, |hugr: &mut H, _| self.run_no_validate(hugr)) } - fn find_needed_nodes( + fn find_needed_nodes( &self, - results: &AnalysisResults>, + results: &AnalysisResults, root: Node, needed: &mut HashSet, ) { @@ -160,6 +160,12 @@ pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { struct ConstFoldContext<'a, H>(&'a mut H); +impl<'a, T: HugrView> AsRef for ConstFoldContext<'a, T> { + fn as_ref(&self) -> &hugr_core::Hugr { + self.0.base_hugr() + } +} + impl<'a, H: HugrView> ConstLoader for ConstFoldContext<'a, H> { fn value_from_opaque( &self, @@ -190,7 +196,7 @@ impl<'a, H: HugrView> ConstLoader for ConstFoldContext<'a, H> { }; // Returning the function body as a value, here, would be sufficient for inlining IndirectCall // but not for transforming to a direct Call. - let func = DescendantsGraph::>::try_new(self.hugr(), node).ok()?; + let func = DescendantsGraph::>::try_new(self, node).ok()?; Some(ValueHandle::new_const_hugr( node, &[], @@ -200,10 +206,6 @@ impl<'a, H: HugrView> ConstLoader for ConstFoldContext<'a, H> { } impl<'a, H: HugrView> TotalContext for ConstFoldContext<'a, H> { - fn hugr(&self) -> &impl HugrView { - self.0 - } - type InterpretableVal = Value; fn interpret_leaf_op( diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index df6342df2..650912b0f 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -18,9 +18,7 @@ use hugr_core::{Hugr, HugrView, Node}; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). -pub trait DFContext: ConstLoader { - fn hugr(&self) -> &impl HugrView; - +pub trait DFContext: ConstLoader + HugrView { /// Given lattice values for each input, update lattice values for the (dataflow) outputs. /// For extension ops only, excluding [MakeTuple] and [UnpackTuple]. /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index f6184e9f1..455b328e5 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -45,7 +45,7 @@ impl Machine { context: C, in_values: impl IntoIterator)>, ) -> AnalysisResults { - let root = context.hugr().root(); + let root = context.root(); self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); run_datalog(context, self.0) @@ -75,38 +75,38 @@ pub(super) fn run_datalog>( lattice in_wire_value(Node, IncomingPort, PV); // receives, on , the value lattice node_in_value_row(Node, ValueRow); // 's inputs are - node(n) <-- for n in ctx.hugr().nodes(); + node(n) <-- for n in ctx.nodes(); - in_wire(n, p) <-- node(n), for (p,_) in ctx.hugr().in_value_types(*n); // Note, gets connected inports only - out_wire(n, p) <-- node(n), for (p,_) in ctx.hugr().out_value_types(*n); // (and likewise) + in_wire(n, p) <-- node(n), for (p,_) in ctx.in_value_types(*n); // Note, gets connected inports only + out_wire(n, p) <-- node(n), for (p,_) in ctx.out_value_types(*n); // (and likewise) parent_of_node(parent, child) <-- - node(child), if let Some(parent) = ctx.hugr().get_parent(*child); + node(child), if let Some(parent) = ctx.get_parent(*child); - input_child(parent, input) <-- node(parent), if let Some([input, _output]) = ctx.hugr().get_io(*parent); - output_child(parent, output) <-- node(parent), if let Some([_input, output]) = ctx.hugr().get_io(*parent); + input_child(parent, input) <-- node(parent), if let Some([input, _output]) = ctx.get_io(*parent); + output_child(parent, output) <-- node(parent), if let Some([_input, output]) = ctx.get_io(*parent); // Initialize all wires to bottom out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p); in_wire_value(n, ip, v) <-- in_wire(n, ip), - if let Some((m, op)) = ctx.hugr().single_linked_output(*n, *ip), + if let Some((m, op)) = ctx.single_linked_output(*n, *ip), out_wire_value(m, op, v); // Prepopulate in_wire_value from in_wire_value_proto. in_wire_value(n, p, PV::bottom()) <-- in_wire(n, p); in_wire_value(n, p, v) <-- for (n, p, v) in in_wire_value_proto.iter(), node(n), - if let Some(sig) = ctx.hugr().signature(*n), + if let Some(sig) = ctx.signature(*n), if sig.input_ports().contains(p); // Assemble in_value_row from in_value's - node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = ctx.hugr().signature(*n); - node_in_value_row(n, ValueRow::single_known(ctx.hugr().signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); + node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = ctx.signature(*n); + node_in_value_row(n, ValueRow::single_known(ctx.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); out_wire_value(n, p, v) <-- node(n), - let op_t = ctx.hugr().get_optype(*n), + let op_t = ctx.get_optype(*n), if !op_t.is_container(), if let Some(sig) = op_t.dataflow_signature(), node_in_value_row(n, vs), @@ -115,7 +115,7 @@ pub(super) fn run_datalog>( // DFG relation dfg_node(Node); // is a `DFG` - dfg_node(n) <-- node(n), if ctx.hugr().get_optype(*n).is_dfg(); + dfg_node(n) <-- node(n), if ctx.get_optype(*n).is_dfg(); out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), input_child(dfg, i), in_wire_value(dfg, p, v); @@ -128,13 +128,13 @@ pub(super) fn run_datalog>( // inputs of tail loop propagate to Input node of child region out_wire_value(i, OutgoingPort::from(p.index()), v) <-- node(tl), - if ctx.hugr().get_optype(*tl).is_tail_loop(), + if ctx.get_optype(*tl).is_tail_loop(), input_child(tl, i), in_wire_value(tl, p, v); // Output node of child region propagate to Input node of child region out_wire_value(in_n, OutgoingPort::from(out_p), v) <-- node(tl), - if let Some(tailloop) = ctx.hugr().get_optype(*tl).as_tail_loop(), + if let Some(tailloop) = ctx.get_optype(*tl).as_tail_loop(), input_child(tl, in_n), output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node @@ -143,7 +143,7 @@ pub(super) fn run_datalog>( // Output node of child region propagate to outputs of tail loop out_wire_value(tl, OutgoingPort::from(out_p), v) <-- node(tl), - if let Some(tailloop) = ctx.hugr().get_optype(*tl).as_tail_loop(), + if let Some(tailloop) = ctx.get_optype(*tl).as_tail_loop(), output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 @@ -154,17 +154,17 @@ pub(super) fn run_datalog>( // is a `Conditional` and its 'th child (a `Case`) is : relation case_node(Node, usize, Node); - conditional_node(n)<-- node(n), if ctx.hugr().get_optype(*n).is_conditional(); + conditional_node(n)<-- node(n), if ctx.get_optype(*n).is_conditional(); case_node(cond, i, case) <-- conditional_node(cond), - for (i, case) in ctx.hugr().children(*cond).enumerate(), - if ctx.hugr().get_optype(case).is_case(); + for (i, case) in ctx.children(*cond).enumerate(), + if ctx.get_optype(case).is_case(); // inputs of conditional propagate into case nodes out_wire_value(i_node, OutgoingPort::from(out_p), v) <-- case_node(cond, case_index, case), input_child(case, i_node), node_in_value_row(cond, in_row), - let conditional = ctx.hugr().get_optype(*cond).as_conditional().unwrap(), + let conditional = ctx.get_optype(*cond).as_conditional().unwrap(), if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), for (out_p, v) in fields.enumerate(); @@ -183,39 +183,39 @@ pub(super) fn run_datalog>( // CFG relation cfg_node(Node); // is a `CFG` - cfg_node(n) <-- node(n), if ctx.hugr().get_optype(*n).is_cfg(); + cfg_node(n) <-- node(n), if ctx.get_optype(*n).is_cfg(); // In `CFG` , basic block is reachable given our knowledge of predicates relation bb_reachable(Node, Node); - bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = ctx.hugr().children(*cfg).next(); + bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = ctx.children(*cfg).next(); bb_reachable(cfg, bb) <-- cfg_node(cfg), bb_reachable(cfg, pred), output_child(pred, pred_out), in_wire_value(pred_out, IncomingPort::from(0), predicate), - for (tag, bb) in ctx.hugr().output_neighbours(*pred).enumerate(), + for (tag, bb) in ctx.output_neighbours(*pred).enumerate(), if predicate.supports_tag(tag); // Inputs of CFG propagate to entry block out_wire_value(i_node, OutgoingPort::from(p.index()), v) <-- cfg_node(cfg), - if let Some(entry) = ctx.hugr().children(*cfg).next(), + if let Some(entry) = ctx.children(*cfg).next(), input_child(entry, i_node), in_wire_value(cfg, p, v); // In `CFG` , values fed along a control-flow edge to // come out of Value outports of . relation _cfg_succ_dest(Node, Node, Node); - _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = ctx.hugr().children(*cfg).nth(1); + _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = ctx.children(*cfg).nth(1); _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), - for blk in ctx.hugr().children(*cfg), - if ctx.hugr().get_optype(blk).is_dataflow_block(), + for blk in ctx.children(*cfg), + if ctx.get_optype(blk).is_dataflow_block(), input_child(blk, inp); // Outputs of each reachable block propagated to successor block or CFG itself out_wire_value(dest, OutgoingPort::from(out_p), v) <-- bb_reachable(cfg, pred), - if let Some(df_block) = ctx.hugr().get_optype(*pred).as_dataflow_block(), - for (succ_n, succ) in ctx.hugr().output_neighbours(*pred).enumerate(), + if let Some(df_block) = ctx.get_optype(*pred).as_dataflow_block(), + for (succ_n, succ) in ctx.output_neighbours(*pred).enumerate(), output_child(pred, out_n), _cfg_succ_dest(cfg, succ, dest), node_in_value_row(out_n, out_in_row), @@ -226,8 +226,8 @@ pub(super) fn run_datalog>( relation func_call(Node, Node); // is a `Call` to `FuncDefn` func_call(call, func_defn) <-- node(call), - if ctx.hugr().get_optype(*call).is_call(), - if let Some(func_defn) = ctx.hugr().static_source(*call); + if ctx.get_optype(*call).is_call(), + if let Some(func_defn) = ctx.static_source(*call); out_wire_value(inp, OutgoingPort::from(p.index()), v) <-- func_call(call, func), @@ -245,7 +245,7 @@ pub(super) fn run_datalog>( .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) .collect(); AnalysisResults { - ctx, + hugr: ctx, out_wire_values, in_wire_value: all_results.in_wire_value, case_reachable: all_results.case_reachable, @@ -259,7 +259,7 @@ fn propagate_leaf_op( ins: &[PV], num_outs: usize, ) -> Option> { - match ctx.hugr().get_optype(n) { + match ctx.get_optype(n) { // Handle basics here. We could instead leave these to DFContext, // but at least we'd want these impls to be easily reusable. op if op.cast::().is_some() => Some(ValueRow::from_iter([PV::new_variant( @@ -282,16 +282,10 @@ fn propagate_leaf_op( OpType::LoadConstant(load_op) => { assert!(ins.is_empty()); // static edge, so need to find constant let const_node = ctx - .hugr() .single_linked_output(n, load_op.constant_port()) .unwrap() .0; - let const_val = ctx - .hugr() - .get_optype(const_node) - .as_const() - .unwrap() - .value(); + let const_val = ctx.get_optype(const_node).as_const().unwrap().value(); Some(ValueRow::single_known( 1, 0, @@ -301,7 +295,6 @@ fn propagate_leaf_op( OpType::LoadFunction(load_op) => { assert!(ins.is_empty()); // static edge let func_node = ctx - .hugr() .single_linked_output(n, load_op.function_port()) .unwrap() .0; diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index d35b1e8fd..5cddc5354 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -2,27 +2,27 @@ use std::collections::HashMap; use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; -use super::{AbstractValue, DFContext, PartialValue}; +use super::{AbstractValue, PartialValue}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). -pub struct AnalysisResults> { - pub(super) ctx: C, +pub struct AnalysisResults { + pub(super) hugr: H, pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, pub(super) case_reachable: Vec<(Node, Node)>, pub(super) bb_reachable: Vec<(Node, Node)>, pub(super) out_wire_values: HashMap>, } -impl> AnalysisResults { +impl AnalysisResults { /// Allows to use the [HugrView] contained within - pub fn hugr(&self) -> &(impl HugrView + '_) { - self.ctx.hugr() + pub fn hugr(&self) -> &H { + &self.hugr } - /// Discards the results, allowing to get back the [DFContext] - pub fn into_ctx(self) -> C { - self.ctx + /// Discards the results, allowing to get back the [HugrView] within + pub fn into_hugr(self) -> H { + self.hugr } /// Gets the lattice value computed for the given wire @@ -36,8 +36,8 @@ impl> AnalysisResults { /// /// [TailLoop]: hugr_core::ops::TailLoop pub fn tail_loop_terminates(&self, node: Node) -> Option { - self.hugr().get_optype(node).as_tail_loop()?; - let [_, out] = self.hugr().get_io(node).unwrap(); + self.hugr.get_optype(node).as_tail_loop()?; + let [_, out] = self.hugr.get_io(node).unwrap(); Some(TailLoopTermination::from_control_value( self.in_wire_value .iter() @@ -54,9 +54,9 @@ impl> AnalysisResults { /// [Case]: hugr_core::ops::Case /// [Conditional]: hugr_core::ops::Conditional pub fn case_reachable(&self, case: Node) -> Option { - self.hugr().get_optype(case).as_case()?; - let cond = self.hugr().get_parent(case)?; - self.hugr().get_optype(cond).as_conditional()?; + self.hugr.get_optype(case).as_case()?; + let cond = self.hugr.get_parent(case)?; + self.hugr.get_optype(cond).as_conditional()?; Some( self.case_reachable .iter() @@ -71,9 +71,9 @@ impl> AnalysisResults { /// [DataflowBlock]: hugr_core::ops::DataflowBlock /// [ExitBlock]: hugr_core::ops::ExitBlock pub fn bb_reachable(&self, bb: Node) -> Option { - let cfg = self.hugr().get_parent(bb)?; // Not really required...?? - self.hugr().get_optype(cfg).as_cfg()?; - let t = self.hugr().get_optype(bb); + let cfg = self.hugr.get_parent(bb)?; // Not really required...?? + self.hugr.get_optype(cfg).as_cfg()?; + let t = self.hugr.get_optype(bb); if !t.is_dataflow_block() && !t.is_exit_block() { return None; }; @@ -85,7 +85,7 @@ impl> AnalysisResults { } } -impl> AnalysisResults +impl AnalysisResults where Value: From, { @@ -103,7 +103,7 @@ where pub fn try_read_wire_value(&self, w: Wire) -> Result> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self - .hugr() + .hugr .out_value_types(w.node()) .find(|(p, _)| *p == w.source()) .unwrap(); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 700ec591e..e73cc8e4f 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -25,12 +25,13 @@ impl AbstractValue for Void {} struct TestContext(H); -impl ConstLoader for TestContext {} -impl DFContext for TestContext { - fn hugr(&self) -> &impl HugrView { - &self.0 +impl AsRef for TestContext { + fn as_ref(&self) -> &Hugr { + self.0.base_hugr() } } +impl ConstLoader for TestContext {} +impl DFContext for TestContext {} // This allows testing creation of tuple/sum Values (only) impl From for Value { diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index 94d7d0795..f325b046a 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -8,8 +8,6 @@ use super::{ConstLoader, DFContext}; /// values that are completely known (in the lattice `V`) rather than partially /// (e.g. no [PartialSum]s of more than one variant, no top/bottom) pub trait TotalContext: ConstLoader { - fn hugr(&self) -> &impl HugrView; - /// Representation of a (single, non-partial) value usable for interpretation type InterpretableVal: From + TryFrom>; @@ -23,11 +21,7 @@ pub trait TotalContext: ConstLoader { ) -> Vec<(OutgoingPort, PartialValue)>; } -impl + ConstLoader> DFContext for T { - fn hugr(&self) -> &impl HugrView { - TotalContext::hugr(self) - } - +impl + ConstLoader + HugrView> DFContext for T { fn interpret_leaf_op( &self, node: Node, From ebebb181a66c884715773e8682f80340ba53ab01 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 29 Oct 2024 12:44:07 +0000 Subject: [PATCH 187/281] RIP TotalContext --- hugr-passes/src/const_fold.rs | 51 +++++++++++++-------- hugr-passes/src/dataflow.rs | 3 -- hugr-passes/src/dataflow/total_context.rs | 54 ----------------------- 3 files changed, 32 insertions(+), 76 deletions(-) delete mode 100644 hugr-passes/src/dataflow/total_context.rs diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 5db190a7f..97bf0133c 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -11,15 +11,18 @@ use hugr_core::{ hugrmut::HugrMut, views::{DescendantsGraph, ExtractHugr, HierarchyView}, }, - ops::{constant::OpaqueValue, handle::FuncID, Const, ExtensionOp, LoadConstant, Value}, + ops::{ + constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant, + Value, + }, types::{EdgeKind, TypeArg}, - HugrView, IncomingPort, Node, OutgoingPort, Wire, + HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire, }; use value_handle::ValueHandle; use crate::{ dataflow::{ - AnalysisResults, ConstLoader, Machine, PartialValue, TailLoopTermination, TotalContext, + AnalysisResults, ConstLoader, DFContext, Machine, PartialValue, TailLoopTermination, }, validation::{ValidatePassError, ValidationLevel}, }; @@ -205,25 +208,35 @@ impl<'a, H: HugrView> ConstLoader for ConstFoldContext<'a, H> { } } -impl<'a, H: HugrView> TotalContext for ConstFoldContext<'a, H> { - type InterpretableVal = Value; - +impl<'a, H: HugrView> DFContext for ConstFoldContext<'a, H> { fn interpret_leaf_op( &self, - n: Node, + node: Node, op: &ExtensionOp, - ins: &[(IncomingPort, Value)], - ) -> Vec<(OutgoingPort, PartialValue)> { - op.constant_fold(&ins).map_or(Vec::new(), |outs| { - outs.into_iter() - .map(|(p, v)| { - ( - p, - self.value_from_const(n, &v), // Hmmm, should (at least) also key by p - ) - }) - .collect() - }) + ins: &[PartialValue], + outs: &mut [PartialValue], + ) { + let sig = op.signature(); + let known_ins = sig + .input_types() + .iter() + .enumerate() + .zip(ins.iter()) + .filter_map(|((i, ty), pv)| { + let v = match pv { + PartialValue::Bottom | PartialValue::Top => None, + PartialValue::Value(v) => Some(v.clone().into()), + PartialValue::PartialSum(ps) => { + Value::try_from(ps.clone().try_into_value::(ty).ok()?).ok() + } + }?; + Some((IncomingPort::from(i), v)) + }) + .collect::>(); + for (p, v) in op.constant_fold(&known_ins).unwrap_or(Vec::new()) { + // Hmmm, we should (at least) key the value also by p + outs[p.index()] = self.value_from_const(node, &v); + } } } diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 650912b0f..c960fb7de 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -100,8 +100,5 @@ fn traverse_value( } } -mod total_context; -pub use total_context::TotalContext; - #[cfg(test)] mod test; diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs deleted file mode 100644 index f325b046a..000000000 --- a/hugr-passes/src/dataflow/total_context.rs +++ /dev/null @@ -1,54 +0,0 @@ -use hugr_core::ops::{DataflowOpTrait, ExtensionOp}; -use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; - -use super::partial_value::{AbstractValue, PartialValue, Sum}; -use super::{ConstLoader, DFContext}; - -/// A simpler interface like [DFContext] but where the context only cares about -/// values that are completely known (in the lattice `V`) rather than partially -/// (e.g. no [PartialSum]s of more than one variant, no top/bottom) -pub trait TotalContext: ConstLoader { - /// Representation of a (single, non-partial) value usable for interpretation - type InterpretableVal: From + TryFrom>; - - /// Interpret an (extension) operation given total values for some of the in-ports - /// `ins` will be a non-empty slice with distinct [IncomingPort]s. - fn interpret_leaf_op( - &self, - node: Node, - e: &ExtensionOp, - ins: &[(IncomingPort, Self::InterpretableVal)], - ) -> Vec<(OutgoingPort, PartialValue)>; -} - -impl + ConstLoader + HugrView> DFContext for T { - fn interpret_leaf_op( - &self, - node: Node, - op: &ExtensionOp, - ins: &[PartialValue], - outs: &mut [PartialValue], - ) { - let sig = op.signature(); - let known_ins = sig - .input_types() - .iter() - .enumerate() - .zip(ins.iter()) - .filter_map(|((i, ty), pv)| { - let v = match pv { - PartialValue::Bottom | PartialValue::Top => None, - PartialValue::Value(v) => Some(v.clone().into()), - PartialValue::PartialSum(ps) => T::InterpretableVal::try_from( - ps.clone().try_into_value::(ty).ok()?, - ) - .ok(), - }?; - Some((IncomingPort::from(i), v)) - }) - .collect::>(); - for (p, v) in self.interpret_leaf_op(node, op, &known_ins) { - outs[p.index()] = v; - } - } -} From 46944f86ab8ca14680b1ad8d5fde1b877b353707 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 29 Oct 2024 12:58:14 +0000 Subject: [PATCH 188/281] Two-stage transform means we don't need to extract HugrView from results --- hugr-passes/src/const_fold.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 97bf0133c..4103412a5 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -73,19 +73,19 @@ impl ConstFoldPass { )) }) .collect::>(); - let hugr_mut = results.into_hugr().0; // and drop 'results' + for (n, inport, v) in wires_to_break { - let parent = hugr_mut.get_parent(n).unwrap(); + let parent = hugr.get_parent(n).unwrap(); let datatype = v.get_type(); // We could try hash-consing identical Consts, but not ATM - let cst = hugr_mut.add_node_with_parent(parent, Const::new(v)); - let lcst = hugr_mut.add_node_with_parent(parent, LoadConstant { datatype }); - hugr_mut.connect(cst, OutgoingPort::from(0), lcst, IncomingPort::from(0)); - hugr_mut.disconnect(n, inport); - hugr_mut.connect(lcst, OutgoingPort::from(0), n, inport); + let cst = hugr.add_node_with_parent(parent, Const::new(v)); + let lcst = hugr.add_node_with_parent(parent, LoadConstant { datatype }); + hugr.connect(cst, OutgoingPort::from(0), lcst, IncomingPort::from(0)); + hugr.disconnect(n, inport); + hugr.connect(lcst, OutgoingPort::from(0), n, inport); } for n in remove_nodes { - hugr_mut.remove_node(n); + hugr.remove_node(n); } Ok(()) } From 0b005e4c566a2eda38861ae0f76ebe83f1592767 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 29 Oct 2024 14:03:40 +0000 Subject: [PATCH 189/281] Go back to Deref, w/ type DFContext::View. --- hugr-passes/src/const_fold.rs | 13 ++++++++----- hugr-passes/src/dataflow.rs | 4 +++- hugr-passes/src/dataflow/results.rs | 19 +++++++------------ hugr-passes/src/dataflow/test.rs | 11 +++++++---- 4 files changed, 25 insertions(+), 22 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 4103412a5..0137eb5e8 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -101,7 +101,7 @@ impl ConstFoldPass { fn find_needed_nodes( &self, - results: &AnalysisResults, + results: &AnalysisResults>, root: Node, needed: &mut HashSet, ) { @@ -163,9 +163,10 @@ pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { struct ConstFoldContext<'a, H>(&'a mut H); -impl<'a, T: HugrView> AsRef for ConstFoldContext<'a, T> { - fn as_ref(&self) -> &hugr_core::Hugr { - self.0.base_hugr() +impl<'a, H: HugrView> std::ops::Deref for ConstFoldContext<'a, H> { + type Target = H; + fn deref(&self) -> &H { + self.0 } } @@ -199,7 +200,7 @@ impl<'a, H: HugrView> ConstLoader for ConstFoldContext<'a, H> { }; // Returning the function body as a value, here, would be sufficient for inlining IndirectCall // but not for transforming to a direct Call. - let func = DescendantsGraph::>::try_new(self, node).ok()?; + let func = DescendantsGraph::>::try_new(&**self, node).ok()?; Some(ValueHandle::new_const_hugr( node, &[], @@ -209,6 +210,8 @@ impl<'a, H: HugrView> ConstLoader for ConstFoldContext<'a, H> { } impl<'a, H: HugrView> DFContext for ConstFoldContext<'a, H> { + type View = H; + fn interpret_leaf_op( &self, node: Node, diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index c960fb7de..bab01b9a7 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -18,7 +18,9 @@ use hugr_core::{Hugr, HugrView, Node}; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). -pub trait DFContext: ConstLoader + HugrView { +pub trait DFContext: ConstLoader + std::ops::Deref { + type View: HugrView; + /// Given lattice values for each input, update lattice values for the (dataflow) outputs. /// For extension ops only, excluding [MakeTuple] and [UnpackTuple]. /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index 5cddc5354..744a635c9 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -2,27 +2,22 @@ use std::collections::HashMap; use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; -use super::{AbstractValue, PartialValue}; +use super::{AbstractValue, DFContext, PartialValue}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). -pub struct AnalysisResults { - pub(super) hugr: H, +pub struct AnalysisResults> { + pub(super) hugr: C, // TODO: Rename pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, pub(super) case_reachable: Vec<(Node, Node)>, pub(super) bb_reachable: Vec<(Node, Node)>, pub(super) out_wire_values: HashMap>, } -impl AnalysisResults { +impl> AnalysisResults { /// Allows to use the [HugrView] contained within - pub fn hugr(&self) -> &H { - &self.hugr - } - - /// Discards the results, allowing to get back the [HugrView] within - pub fn into_hugr(self) -> H { - self.hugr + pub fn hugr(&self) -> &C::View { + &*self.hugr } /// Gets the lattice value computed for the given wire @@ -85,7 +80,7 @@ impl AnalysisResults { } } -impl AnalysisResults +impl> AnalysisResults where Value: From, { diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index e73cc8e4f..c7d4e7b7e 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -25,13 +25,16 @@ impl AbstractValue for Void {} struct TestContext(H); -impl AsRef for TestContext { - fn as_ref(&self) -> &Hugr { - self.0.base_hugr() +impl std::ops::Deref for TestContext { + type Target = H; + fn deref(&self) -> &H { + &self.0 } } impl ConstLoader for TestContext {} -impl DFContext for TestContext {} +impl DFContext for TestContext { + type View = H; +} // This allows testing creation of tuple/sum Values (only) impl From for Value { From c4224fee2ea25d11a4cc626c690efc6e94ab60d1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 29 Oct 2024 19:42:47 +0000 Subject: [PATCH 190/281] clippy+doc --- hugr-passes/src/const_fold.rs | 4 ++-- hugr-passes/src/const_fold/value_handle.rs | 4 ++-- hugr-passes/src/dataflow.rs | 2 ++ hugr-passes/src/dataflow/results.rs | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 0137eb5e8..d4cbe1bab 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -194,7 +194,7 @@ impl<'a, H: HugrView> ConstLoader for ConstFoldContext<'a, H> { } fn value_from_function(&self, node: Node, type_args: &[TypeArg]) -> Option { - if type_args.len() > 0 { + if !type_args.is_empty() { // TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709) return None; }; @@ -236,7 +236,7 @@ impl<'a, H: HugrView> DFContext for ConstFoldContext<'a, H> { Some((IncomingPort::from(i), v)) }) .collect::>(); - for (p, v) in op.constant_fold(&known_ins).unwrap_or(Vec::new()) { + for (p, v) in op.constant_fold(&known_ins).unwrap_or_default() { // Hmmm, we should (at least) key the value also by p outs[p.index()] = self.value_from_const(node, &v); } diff --git a/hugr-passes/src/const_fold/value_handle.rs b/hugr-passes/src/const_fold/value_handle.rs index 6c8724c14..eb92aa018 100644 --- a/hugr-passes/src/const_fold/value_handle.rs +++ b/hugr-passes/src/const_fold/value_handle.rs @@ -110,10 +110,10 @@ impl From for Value { match value { ValueHandle::Hashable(HashedConst { val, .. }) | ValueHandle::Unhashable(_, Either::Left(val)) => Value::Extension { - e: Arc::unwrap_or_clone(val), + e: Arc::try_unwrap(val).unwrap_or_else(|a| a.as_ref().clone()), }, ValueHandle::Unhashable(_, Either::Right(hugr)) => { - Value::function(Arc::unwrap_or_clone(hugr)) + Value::function(Arc::try_unwrap(hugr).unwrap_or_else(|a| a.as_ref().clone())) .map_err(|e| e.to_string()) .unwrap() } diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index bab01b9a7..69942eafc 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -19,6 +19,8 @@ use hugr_core::{Hugr, HugrView, Node}; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). pub trait DFContext: ConstLoader + std::ops::Deref { + /// Type of view contained within this context. (Ideally we'd constrain + /// by `std::ops::Deref> { impl> AnalysisResults { /// Allows to use the [HugrView] contained within pub fn hugr(&self) -> &C::View { - &*self.hugr + &self.hugr } /// Gets the lattice value computed for the given wire From 87eb700565aba68c29feb26aeef699c48586bbd9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 29 Oct 2024 19:52:40 +0000 Subject: [PATCH 191/281] And back to Deref, should allow using a region view not the whole Hugr --- hugr-passes/src/dataflow.rs | 6 +++++- hugr-passes/src/dataflow/datalog.rs | 4 ++-- hugr-passes/src/dataflow/results.rs | 33 +++++++++++++++++------------ hugr-passes/src/dataflow/test.rs | 11 ++++++---- 4 files changed, 33 insertions(+), 21 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index c960fb7de..69942eafc 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -18,7 +18,11 @@ use hugr_core::{Hugr, HugrView, Node}; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). -pub trait DFContext: ConstLoader + HugrView { +pub trait DFContext: ConstLoader + std::ops::Deref { + /// Type of view contained within this context. (Ideally we'd constrain + /// by `std::ops::Deref>( .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) .collect(); AnalysisResults { - hugr: ctx, + ctx, out_wire_values, in_wire_value: all_results.in_wire_value, case_reachable: all_results.case_reachable, @@ -308,7 +308,7 @@ fn propagate_leaf_op( } OpType::ExtensionOp(e) => { // Interpret op. - let init = if ins.iter().contains(&PartialValue::Bottom) { + let init = if ins.iter().contains(&PartialValue::Bottom) { // So far we think one or more inputs can't happen. // So, don't pollute outputs with Top, and wait for better knowledge of inputs. PartialValue::Bottom diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index f457ef68c..b18a3c704 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -2,19 +2,24 @@ use std::collections::HashMap; use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; -use super::{AbstractValue, PartialValue}; +use super::{AbstractValue, DFContext, PartialValue}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). -pub struct AnalysisResults { - pub(super) hugr: H, +pub struct AnalysisResults> { + pub(super) ctx: C, pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, pub(super) case_reachable: Vec<(Node, Node)>, pub(super) bb_reachable: Vec<(Node, Node)>, pub(super) out_wire_values: HashMap>, } -impl AnalysisResults { +impl> AnalysisResults { + /// Allows to use the [HugrView] contained within + pub fn hugr(&self) -> &C::View { + &self.ctx + } + /// Gets the lattice value computed for the given wire pub fn read_out_wire(&self, w: Wire) -> Option> { self.out_wire_values.get(&w).cloned() @@ -26,8 +31,8 @@ impl AnalysisResults { /// /// [TailLoop]: hugr_core::ops::TailLoop pub fn tail_loop_terminates(&self, node: Node) -> Option { - self.hugr.get_optype(node).as_tail_loop()?; - let [_, out] = self.hugr.get_io(node).unwrap(); + self.hugr().get_optype(node).as_tail_loop()?; + let [_, out] = self.hugr().get_io(node).unwrap(); Some(TailLoopTermination::from_control_value( self.in_wire_value .iter() @@ -44,9 +49,9 @@ impl AnalysisResults { /// [Case]: hugr_core::ops::Case /// [Conditional]: hugr_core::ops::Conditional pub fn case_reachable(&self, case: Node) -> Option { - self.hugr.get_optype(case).as_case()?; - let cond = self.hugr.get_parent(case)?; - self.hugr.get_optype(cond).as_conditional()?; + self.hugr().get_optype(case).as_case()?; + let cond = self.hugr().get_parent(case)?; + self.hugr().get_optype(cond).as_conditional()?; Some( self.case_reachable .iter() @@ -61,9 +66,9 @@ impl AnalysisResults { /// [DataflowBlock]: hugr_core::ops::DataflowBlock /// [ExitBlock]: hugr_core::ops::ExitBlock pub fn bb_reachable(&self, bb: Node) -> Option { - let cfg = self.hugr.get_parent(bb)?; // Not really required...?? - self.hugr.get_optype(cfg).as_cfg()?; - let t = self.hugr.get_optype(bb); + let cfg = self.hugr().get_parent(bb)?; // Not really required...?? + self.hugr().get_optype(cfg).as_cfg()?; + let t = self.hugr().get_optype(bb); if !t.is_dataflow_block() && !t.is_exit_block() { return None; }; @@ -75,7 +80,7 @@ impl AnalysisResults { } } -impl AnalysisResults +impl> AnalysisResults where Value: From, { @@ -93,7 +98,7 @@ where pub fn try_read_wire_value(&self, w: Wire) -> Result> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self - .hugr + .hugr() .out_value_types(w.node()) .find(|(p, _)| *p == w.source()) .unwrap(); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index e73cc8e4f..c7d4e7b7e 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -25,13 +25,16 @@ impl AbstractValue for Void {} struct TestContext(H); -impl AsRef for TestContext { - fn as_ref(&self) -> &Hugr { - self.0.base_hugr() +impl std::ops::Deref for TestContext { + type Target = H; + fn deref(&self) -> &H { + &self.0 } } impl ConstLoader for TestContext {} -impl DFContext for TestContext {} +impl DFContext for TestContext { + type View = H; +} // This allows testing creation of tuple/sum Values (only) impl From for Value { From a49221b3b172ac3bb514a763d072c66d8e1747cd Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 10:49:15 +0000 Subject: [PATCH 192/281] fix doclink --- hugr-passes/src/dataflow.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 69942eafc..0b7bf03ed 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -47,7 +47,7 @@ pub trait DFContext: ConstLoader + std::ops::Deref { pub trait ConstLoader { /// Produces an abstract value from a constant. The default impl /// traverses the constant [Value] to its leaves ([Value::Extension] and [Value::Function]), - /// converts these using [Self::value_from_custom_const] and [Self::value_from_const_hugr], + /// converts these using [Self::value_from_opaque] and [Self::value_from_const_hugr], /// and builds nested [PartialValue::new_variant] to represent the structure. fn value_from_const(&self, n: Node, cst: &Value) -> PartialValue { traverse_value(self, n, &mut Vec::new(), cst) From 71ea55d1e5e7970d75d0ea57c34a41c64557806a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 10:40:21 +0000 Subject: [PATCH 193/281] PartialSum::try_into_value also uses Option<...> as error-type --- hugr-passes/src/dataflow/partial_value.rs | 59 ++++++++++++----------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 5b5695dd0..8cd5c12fb 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -141,41 +141,37 @@ impl PartialSum { self.0.contains_key(&tag) } - /// Turns this instance into a [Sum] if it has exactly one possible tag, - /// otherwise failing and returning itself back unmodified (also if there is another - /// error, e.g. this instance is not described by `typ`). - // ALAN is this too parametric? Should we fix V2 == Value? Is the 'Self' error useful (no?) + /// Turns this instance into a [Sum] of some target value type `V2`, + /// *if* this PartialSum has exactly one possible tag. + /// + /// # Errors + /// `None` if this PartialSum had multiple possible tags; or, if there was a single + /// tag, but `typ` was not a [TypeEnum::Sum] supporting that tag and containing no + /// row variables within that variant and of the correct number of variants + /// `Some(e)` if none of the error conditions above applied, but there was an error + /// `e` in converting one of the variant elements into `V2` via [PartialValue::try_into_value] pub fn try_into_value + TryFrom>>( self, typ: &Type, - ) -> Result, Self> { - let Ok((k, v)) = self.0.iter().exactly_one() else { - Err(self)? - }; - + ) -> Result, Option<>>::Error>> { + let (k, v) = self.0.iter().exactly_one().map_err(|_| None)?; let TypeEnum::Sum(st) = typ.as_type_enum() else { - Err(self)? + Err(None)? }; let Some(r) = st.get_variant(*k) else { - Err(self)? - }; - let Ok(r) = TypeRow::try_from(r.clone()) else { - Err(self)? + Err(None)? }; + let r: TypeRow = r.clone().try_into().map_err(|_| None)?; if v.len() != r.len() { - return Err(self); - } - match zip_eq(v, r.iter()) - .map(|(v, t)| v.clone().try_into_value(t)) - .collect::, _>>() - { - Ok(values) => Ok(Sum { - tag: *k, - values, - st: st.clone(), - }), - Err(_) => Err(self), + return Err(None); } + Ok(Sum { + tag: *k, + values: zip_eq(v, r.iter()) + .map(|(v, t)| v.clone().try_into_value(t)) + .collect::, _>>()?, + st: st.clone(), + }) } } @@ -301,8 +297,13 @@ impl PartialValue { } } - /// Extracts a value (in any representation supporting both leaf values and sums) - // ALAN is this too parametric? Should we fix V2 == Value? Is the error useful (should we have 'Self') or is it a smell? + /// Turns this instance into a target value type `V2` if it is a single value, + /// or a [PartialValue::PartialSum] convertible by [PartialSum::try_into_value]. + /// + /// # Errors + /// + /// `None` if this is [Bottom](PartialValue::Bottom) or [Top](PartialValue::Top), + /// otherwise as per [PartialSum::try_into_value] pub fn try_into_value + TryFrom>>( self, typ: &Type, @@ -310,7 +311,7 @@ impl PartialValue { match self { Self::Value(v) => Ok(V2::from(v.clone())), Self::PartialSum(ps) => { - let v = ps.try_into_value(typ).map_err(|_| None)?; + let v = ps.try_into_value(typ)?; V2::try_from(v).map_err(Some) } _ => Err(None), From ea9db2e964e6e9345a4fab747fd0121d5f6f292a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 11:30:02 +0000 Subject: [PATCH 194/281] Proper errors from try_into_value, Option from try_read_wire_value --- hugr-passes/src/dataflow/partial_value.rs | 66 ++++++++++++++++------- hugr-passes/src/dataflow/results.rs | 13 +++-- 2 files changed, 55 insertions(+), 24 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 8cd5c12fb..b03e2d842 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -6,6 +6,7 @@ use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; +use thiserror::Error; /// Trait for an underlying domain of abstract values which can form the *elements* of a /// [PartialValue] and thus be used in dataflow analysis. @@ -150,31 +151,57 @@ impl PartialSum { /// row variables within that variant and of the correct number of variants /// `Some(e)` if none of the error conditions above applied, but there was an error /// `e` in converting one of the variant elements into `V2` via [PartialValue::try_into_value] - pub fn try_into_value + TryFrom>>( + pub fn try_into_value + TryFrom, Error=E>>( self, typ: &Type, - ) -> Result, Option<>>::Error>> { - let (k, v) = self.0.iter().exactly_one().map_err(|_| None)?; - let TypeEnum::Sum(st) = typ.as_type_enum() else { - Err(None)? - }; - let Some(r) = st.get_variant(*k) else { - Err(None)? + ) -> Result, ExtractValueError> { + let Ok((k, v)) = self.0.iter().exactly_one() else { + return Err(ExtractValueError::MultipleVariants(self)); }; - let r: TypeRow = r.clone().try_into().map_err(|_| None)?; - if v.len() != r.len() { - return Err(None); + if let TypeEnum::Sum(st) = typ.as_type_enum() { + if let Some(r) = st.get_variant(*k) { + if let Ok(r) = TypeRow::try_from(r.clone()) { + if v.len() == r.len() { + return Ok(Sum { + tag: *k, + values: zip_eq(v, r.iter()) + .map(|(v, t)| v.clone().try_into_value(t)) + .collect::, _>>()?, + st: st.clone(), + }); + } + } + } } - Ok(Sum { + Err(ExtractValueError::BadSumType { + typ: typ.clone(), tag: *k, - values: zip_eq(v, r.iter()) - .map(|(v, t)| v.clone().try_into_value(t)) - .collect::, _>>()?, - st: st.clone(), + num_elements: v.len(), }) } } +#[derive(Clone, Debug, PartialEq, Eq, Error)] +#[allow(missing_docs)] +pub enum ExtractValueError { + #[error("PartialSum value had multiple possible tags: {0}")] + MultipleVariants(PartialSum), + #[error("Value contained `Bottom`")] + ValueIsBottom, + #[error("Value contained `Top`")] + ValueIsTop, + #[error("Could not convert element from abstract value into concrete: {0}")] + CouldNotConvert(V, #[source] E), + #[error("Could not build Sum from concrete element values")] + CouldNotBuildSum(#[source] E), + #[error("Expected a SumType with tag {tag} having {num_elements} elements, found {typ}")] + BadSumType { + typ: Type, + tag: usize, + num_elements: usize, + }, +} + impl PartialSum { /// If this Sum might have the specified `tag`, get the elements inside that tag. pub fn variant_values(&self, variant: usize) -> Option>> { @@ -307,14 +334,15 @@ impl PartialValue { pub fn try_into_value + TryFrom>>( self, typ: &Type, - ) -> Result>>::Error>> { + ) -> Result>>::Error>> { match self { Self::Value(v) => Ok(V2::from(v.clone())), Self::PartialSum(ps) => { let v = ps.try_into_value(typ)?; - V2::try_from(v).map_err(Some) + V2::try_from(v).map_err(ExtractValueError::CouldNotBuildSum) } - _ => Err(None), + Self::Top => Err(ExtractValueError::ValueIsTop), + Self::Bottom => Err(ExtractValueError::ValueIsBottom), } } } diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index b18a3c704..f51c6353e 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; -use super::{AbstractValue, DFContext, PartialValue}; +use super::{partial_value::ExtractValueError, AbstractValue, DFContext, PartialValue}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). @@ -89,20 +89,23 @@ where /// a [Sum](PartialValue::PartialSum) with a single known tag.) /// /// # Errors - /// `None` if the analysis did not result in a single value on that wire - /// `Some(e)` if conversion to a [Value] produced a [ConstTypeError] + /// `None` if the analysis did not produce a result for that wire + /// `Some(e)` if conversion to a [Value] failed with error `e` /// /// # Panics /// /// If a [Type](hugr_core::types::Type) for the specified wire could not be extracted from the Hugr - pub fn try_read_wire_value(&self, w: Wire) -> Result> { + pub fn try_read_wire_value( + &self, + w: Wire, + ) -> Result>> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self .hugr() .out_value_types(w.node()) .find(|(p, _)| *p == w.source()) .unwrap(); - v.try_into_value(&typ) + v.try_into_value(&typ).map_err(Some) } } From df3152356bc7a5657ba5289c430dde72599e320c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 11:53:41 +0000 Subject: [PATCH 195/281] fmt --- hugr-passes/src/dataflow/partial_value.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index b03e2d842..648034400 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -151,7 +151,7 @@ impl PartialSum { /// row variables within that variant and of the correct number of variants /// `Some(e)` if none of the error conditions above applied, but there was an error /// `e` in converting one of the variant elements into `V2` via [PartialValue::try_into_value] - pub fn try_into_value + TryFrom, Error=E>>( + pub fn try_into_value + TryFrom, Error = E>>( self, typ: &Type, ) -> Result, ExtractValueError> { From e78c006931fb007734b80335cc9a753371d64814 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 11:54:29 +0000 Subject: [PATCH 196/281] ...and fix by using try_into_value(!)... --- hugr-passes/src/const_fold.rs | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index d4cbe1bab..e7eb382d5 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -13,7 +13,6 @@ use hugr_core::{ }, ops::{ constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant, - Value, }, types::{EdgeKind, TypeArg}, HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire, @@ -226,14 +225,7 @@ impl<'a, H: HugrView> DFContext for ConstFoldContext<'a, H> { .enumerate() .zip(ins.iter()) .filter_map(|((i, ty), pv)| { - let v = match pv { - PartialValue::Bottom | PartialValue::Top => None, - PartialValue::Value(v) => Some(v.clone().into()), - PartialValue::PartialSum(ps) => { - Value::try_from(ps.clone().try_into_value::(ty).ok()?).ok() - } - }?; - Some((IncomingPort::from(i), v)) + Some((IncomingPort::from(i), pv.clone().try_into_value(ty).ok()?)) }) .collect::>(); for (p, v) in op.constant_fold(&known_ins).unwrap_or_default() { From a490874e98fdc683aa890388de0fb1b49b5a270f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 12:23:01 +0000 Subject: [PATCH 197/281] Add test running on region --- hugr-passes/src/dataflow/test.rs | 55 ++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index c7d4e7b7e..80bd484ef 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,6 +1,8 @@ use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::builder::{CFGBuilder, Container, DataflowHugr}; +use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; +use hugr_core::ops::handle::DfgID; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -454,3 +456,56 @@ fn test_call( assert_eq!(res0, out); assert_eq!(res1, out); } + +#[test] +fn test_region() { + let mut builder = + DFGBuilder::new(Signature::new(type_row![BOOL_T], type_row![BOOL_T;2])).unwrap(); + let [in_w] = builder.input_wires_arr(); + let cst_w = builder.add_load_const(Value::false_val()); + let nested = builder + .dfg_builder(Signature::new_endo(type_row![BOOL_T; 2]), [in_w, cst_w]) + .unwrap(); + let nested_ins = nested.input_wires(); + let nested = nested.finish_with_outputs(nested_ins).unwrap(); + let hugr = builder + .finish_prelude_hugr_with_outputs(nested.outputs()) + .unwrap(); + let [nested_input, _] = hugr.get_io(nested.node()).unwrap(); + let whole_hugr_results = Machine::default().run(TestContext(&hugr), [(0.into(), pv_true())]); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(nested_input, 0)), + Some(pv_true()) + ); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(nested_input, 1)), + Some(pv_false()) + ); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(hugr.root(), 0)), + Some(pv_true()) + ); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(hugr.root(), 1)), + Some(pv_false()) + ); + + let subview = DescendantsGraph::::try_new(&hugr, nested.node()).unwrap(); + // Do not provide a value on the second input (constant false in the whole hugr, above) + let sub_hugr_results = Machine::default().run(TestContext(subview), [(0.into(), pv_true())]); + assert_eq!( + sub_hugr_results.read_out_wire(Wire::new(nested_input, 0)), + Some(pv_true()) + ); + // TODO this should really be `Top` - safety says we have to assume it could be anything, not that it can't happen + assert_eq!( + sub_hugr_results.read_out_wire(Wire::new(nested_input, 1)), + Some(PartialValue::Bottom) + ); + for w in [0, 1] { + assert_eq!( + sub_hugr_results.read_out_wire(Wire::new(hugr.root(), w)), + None + ); + } +} From 3af39aa0a133f41a7c5bf38e027982722efb5b01 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 12:42:06 +0000 Subject: [PATCH 198/281] fix: provide PartialValue::Top for unspecified Hugr inputs --- hugr-passes/src/dataflow/datalog.rs | 20 ++++++++++++++++++++ hugr-passes/src/dataflow/test.rs | 3 +-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 4fc54b511..facab8595 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -48,6 +48,26 @@ impl Machine { let root = context.root(); self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); + // Any inputs we don't have values for, we must assume `Top` to ensure safety of analysis + // (Consider: for a conditional that selects *either* the unknown input *or* value V, + // analysis must produce Top == we-know-nothing, not `V` !) + let mut have_inputs = + vec![false; context.signature(root).unwrap_or_default().input_count()]; + self.0.iter().for_each(|(n, p, _)| { + if n == &root { + if let Some(e) = have_inputs.get_mut(p.index()) { + *e = true; + } + } + }); + for (i, b) in have_inputs.into_iter().enumerate() { + if !b { + self.0 + .push((root, IncomingPort::from(i), PartialValue::Top)); + } + } + // Note/TODO, if analysis is running on a subregion then we should do similar + // for any nonlocal edges providing values from outside the region. run_datalog(context, self.0) } } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 80bd484ef..8ec0f9dee 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -497,10 +497,9 @@ fn test_region() { sub_hugr_results.read_out_wire(Wire::new(nested_input, 0)), Some(pv_true()) ); - // TODO this should really be `Top` - safety says we have to assume it could be anything, not that it can't happen assert_eq!( sub_hugr_results.read_out_wire(Wire::new(nested_input, 1)), - Some(PartialValue::Bottom) + Some(PartialValue::Top) ); for w in [0, 1] { assert_eq!( From dc159993ec4567a21708907f4d369bfb91679d31 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 13:31:11 +0000 Subject: [PATCH 199/281] try_into_value allows TryFrom by giving ExtractValueError *2* errortype params --- hugr-passes/src/dataflow/partial_value.rs | 17 +++++++++-------- hugr-passes/src/dataflow/results.rs | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 648034400..f601af86b 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -151,10 +151,10 @@ impl PartialSum { /// row variables within that variant and of the correct number of variants /// `Some(e)` if none of the error conditions above applied, but there was an error /// `e` in converting one of the variant elements into `V2` via [PartialValue::try_into_value] - pub fn try_into_value + TryFrom, Error = E>>( + pub fn try_into_value + TryFrom, Error = SE>>( self, typ: &Type, - ) -> Result, ExtractValueError> { + ) -> Result, ExtractValueError> { let Ok((k, v)) = self.0.iter().exactly_one() else { return Err(ExtractValueError::MultipleVariants(self)); }; @@ -183,7 +183,7 @@ impl PartialSum { #[derive(Clone, Debug, PartialEq, Eq, Error)] #[allow(missing_docs)] -pub enum ExtractValueError { +pub enum ExtractValueError { #[error("PartialSum value had multiple possible tags: {0}")] MultipleVariants(PartialSum), #[error("Value contained `Bottom`")] @@ -191,9 +191,9 @@ pub enum ExtractValueError { #[error("Value contained `Top`")] ValueIsTop, #[error("Could not convert element from abstract value into concrete: {0}")] - CouldNotConvert(V, #[source] E), + CouldNotConvert(V, #[source] VE), #[error("Could not build Sum from concrete element values")] - CouldNotBuildSum(#[source] E), + CouldNotBuildSum(#[source] SE), #[error("Expected a SumType with tag {tag} having {num_elements} elements, found {typ}")] BadSumType { typ: Type, @@ -331,12 +331,13 @@ impl PartialValue { /// /// `None` if this is [Bottom](PartialValue::Bottom) or [Top](PartialValue::Top), /// otherwise as per [PartialSum::try_into_value] - pub fn try_into_value + TryFrom>>( + pub fn try_into_value + TryFrom, Error = SE>>( self, typ: &Type, - ) -> Result>>::Error>> { + ) -> Result> { match self { - Self::Value(v) => Ok(V2::from(v.clone())), + Self::Value(v) => V2::try_from(v.clone()) + .map_err(|e| ExtractValueError::CouldNotConvert(v.clone(), e)), Self::PartialSum(ps) => { let v = ps.try_into_value(typ)?; V2::try_from(v).map_err(ExtractValueError::CouldNotBuildSum) diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index f51c6353e..713900acc 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -98,7 +98,7 @@ where pub fn try_read_wire_value( &self, w: Wire, - ) -> Result>> { + ) -> Result>> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self .hugr() From 2d81264a3e5287230c975dc27ac2602149114ad9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 13:38:49 +0000 Subject: [PATCH 200/281] improve docs --- hugr-passes/src/dataflow/partial_value.rs | 23 +++++++++++++---------- hugr-passes/src/dataflow/results.rs | 2 +- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index f601af86b..12f1733d3 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -142,15 +142,14 @@ impl PartialSum { self.0.contains_key(&tag) } - /// Turns this instance into a [Sum] of some target value type `V2`, + /// Turns this instance into a [Sum] of some "concrete" value type `V2`, /// *if* this PartialSum has exactly one possible tag. /// /// # Errors - /// `None` if this PartialSum had multiple possible tags; or, if there was a single - /// tag, but `typ` was not a [TypeEnum::Sum] supporting that tag and containing no - /// row variables within that variant and of the correct number of variants - /// `Some(e)` if none of the error conditions above applied, but there was an error - /// `e` in converting one of the variant elements into `V2` via [PartialValue::try_into_value] + /// + /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] + /// supporting the single possible tag with the correct number of elements and no row variables; + /// or if converting a child element failed via [PartialValue::try_into_value]. pub fn try_into_value + TryFrom, Error = SE>>( self, typ: &Type, @@ -181,6 +180,8 @@ impl PartialSum { } } +/// An error converting a [PartialValue] or [PartialSum] into a concrete value type +/// via [PartialValue::try_into_value] or [PartialSum::try_into_value] #[derive(Clone, Debug, PartialEq, Eq, Error)] #[allow(missing_docs)] pub enum ExtractValueError { @@ -324,13 +325,15 @@ impl PartialValue { } } - /// Turns this instance into a target value type `V2` if it is a single value, - /// or a [PartialValue::PartialSum] convertible by [PartialSum::try_into_value]. + /// Turns this instance into some "concrete" value type `V2`, *if* it is a single value, + /// or a [Sum](PartialValue::PartialSum) (of a single tag) convertible by + /// [PartialSum::try_into_value]. /// /// # Errors /// - /// `None` if this is [Bottom](PartialValue::Bottom) or [Top](PartialValue::Top), - /// otherwise as per [PartialSum::try_into_value] + /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) + /// that could not be converted into a [Sum] by [PartialSum::try_into_value] (e.g. if `typ` is + /// incorrect), or if that [Sum] could not be converted into a `V2`. pub fn try_into_value + TryFrom, Error = SE>>( self, typ: &Type, diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index 713900acc..c1e554154 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -90,7 +90,7 @@ where /// /// # Errors /// `None` if the analysis did not produce a result for that wire - /// `Some(e)` if conversion to a [Value] failed with error `e` + /// `Some(e)` if conversion to a [Value] failed with error `e`, see [PartialValue::try_into_value] /// /// # Panics /// From b61d2520231fd29ce5fb23dd489abf64726c7401 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 14:10:36 +0000 Subject: [PATCH 201/281] Parametrize Machine::try_read_wire_value the same way --- hugr-passes/src/dataflow/results.rs | 24 +++++++++++------------- hugr-passes/src/dataflow/test.rs | 12 ++++++------ 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index c1e554154..6c90e33b3 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; -use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; +use hugr_core::{HugrView, IncomingPort, Node, PortIndex, Wire}; -use super::{partial_value::ExtractValueError, AbstractValue, DFContext, PartialValue}; +use super::{partial_value::ExtractValueError, AbstractValue, DFContext, PartialValue, Sum}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). @@ -78,27 +78,25 @@ impl> AnalysisResults { .any(|(cfg2, bb2)| *cfg2 == cfg && *bb2 == bb), ) } -} -impl> AnalysisResults -where - Value: From, -{ - /// Reads a [Value] from an output wire, if the lattice value computed for it can be turned - /// into one. (The lattice value must be either a single [Value](PartialValue::Value) or - /// a [Sum](PartialValue::PartialSum) with a single known tag.) + /// Reads a concrete representation of the value on an output wire, if the lattice value + /// computed for the wire can be turned into such. (The lattice value must be either a + /// [PartialValue::Value] or a [PartialValue::PartialSum] with a single possible tag.) /// /// # Errors /// `None` if the analysis did not produce a result for that wire - /// `Some(e)` if conversion to a [Value] failed with error `e`, see [PartialValue::try_into_value] + /// `Some(e)` if conversion to a concrete value failed with error `e`, see [PartialValue::try_into_value] /// /// # Panics /// /// If a [Type](hugr_core::types::Type) for the specified wire could not be extracted from the Hugr - pub fn try_read_wire_value( + pub fn try_read_wire_value( &self, w: Wire, - ) -> Result>> { + ) -> Result>> + where + V2: TryFrom + TryFrom, Error = SE>, + { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self .hugr() diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 8ec0f9dee..057242d03 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -67,7 +67,7 @@ fn test_make_tuple() { let results = Machine::default().run(TestContext(hugr), []); - let x = results.try_read_wire_value(v3).unwrap(); + let x: Value = results.try_read_wire_value(v3).unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); } @@ -83,9 +83,9 @@ fn test_unpack_tuple_const() { let results = Machine::default().run(TestContext(hugr), []); - let o1_r = results.try_read_wire_value(o1).unwrap(); + let o1_r: Value = results.try_read_wire_value(o1).unwrap(); assert_eq!(o1_r, Value::false_val()); - let o2_r = results.try_read_wire_value(o2).unwrap(); + let o2_r: Value = results.try_read_wire_value(o2).unwrap(); assert_eq!(o2_r, Value::true_val()); } @@ -106,7 +106,7 @@ fn test_tail_loop_never_iterates() { let results = Machine::default().run(TestContext(hugr), []); - let o_r = results.try_read_wire_value(tl_o).unwrap(); + let o_r: Value = results.try_read_wire_value(tl_o).unwrap(); assert_eq!(o_r, r_v); assert_eq!( Some(TailLoopTermination::NeverContinues), @@ -291,9 +291,9 @@ fn test_conditional() { )); let results = Machine::default().run(TestContext(hugr), [(0.into(), arg_pv)]); - let cond_r1 = results.try_read_wire_value(cond_o1).unwrap(); + let cond_r1: Value = results.try_read_wire_value(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); - assert!(results.try_read_wire_value(cond_o2).is_err()); + assert!(results.try_read_wire_value::(cond_o2).is_err()); assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only assert_eq!(results.case_reachable(case2.node()), Some(true)); From 8cac194e534df144a5ef7cb5a8a7ce5c651c5ce6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 14:45:41 +0000 Subject: [PATCH 202/281] tweaks --- hugr-passes/Cargo.toml | 2 +- hugr-passes/src/dataflow.rs | 18 ++++---- hugr-passes/src/dataflow/datalog.rs | 54 +++++++++++------------ hugr-passes/src/dataflow/partial_value.rs | 1 - hugr-passes/src/dataflow/test.rs | 1 + 5 files changed, 37 insertions(+), 39 deletions(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 88d7fc62b..818aa069c 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -30,4 +30,4 @@ extension_inference = ["hugr-core/extension_inference"] rstest = { workspace = true } proptest = { workspace = true } proptest-derive = { workspace = true } -proptest-recurse = { version = "0.5.0" } \ No newline at end of file +proptest-recurse = { version = "0.5.0" } diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 0b7bf03ed..dd3e6d2c0 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -20,7 +20,7 @@ use hugr_core::{Hugr, HugrView, Node}; /// must implement this trait (including providing an appropriate domain type `V`). pub trait DFContext: ConstLoader + std::ops::Deref { /// Type of view contained within this context. (Ideally we'd constrain - /// by `std::ops::Deref` but that's not stable yet.) type View: HugrView; /// Given lattice values for each input, update lattice values for the (dataflow) outputs. @@ -41,12 +41,14 @@ pub trait DFContext: ConstLoader + std::ops::Deref { } } -/// Trait for loading [PartialValue]s from constants in a Hugr. The default -/// traverses [Sum](Value::Sum) constants to their non-Sum leaves but represents -/// each leaf as [PartialValue::Top]. +/// Trait for loading [PartialValue]s from constant [Value]s in a Hugr. +/// Implementors will likely want to override some/all of [Self::value_from_opaque], +/// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults +/// are "correct" but maximally conservative (minimally informative). pub trait ConstLoader { - /// Produces an abstract value from a constant. The default impl - /// traverses the constant [Value] to its leaves ([Value::Extension] and [Value::Function]), + /// Produces a [PartialValue] from a constant. The default impl (expected + /// to be appropriate in most cases) traverses [Sum](Value::Sum) constants + /// to their leaves ([Value::Extension] and [Value::Function]), /// converts these using [Self::value_from_opaque] and [Self::value_from_const_hugr], /// and builds nested [PartialValue::new_variant] to represent the structure. fn value_from_const(&self, n: Node, cst: &Value) -> PartialValue { @@ -65,8 +67,8 @@ pub trait ConstLoader { None } - /// Produces an abstract value from a [FuncDefn] or [FuncDecl] node (that has been loaded - /// via a [LoadFunction]), if possible. + /// Produces an abstract value from a [FuncDefn] or [FuncDecl] node + /// (that has been loaded via a [LoadFunction]), if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. /// /// [FuncDefn]: hugr_core::ops::FuncDefn diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index facab8595..81f415fd0 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,5 +1,8 @@ //! [ascent] datalog implementation of analysis. +use std::collections::HashSet; +use std::hash::RandomState; + use ascent::lattice::BoundedLattice; use itertools::Itertools; @@ -51,20 +54,16 @@ impl Machine { // Any inputs we don't have values for, we must assume `Top` to ensure safety of analysis // (Consider: for a conditional that selects *either* the unknown input *or* value V, // analysis must produce Top == we-know-nothing, not `V` !) - let mut have_inputs = - vec![false; context.signature(root).unwrap_or_default().input_count()]; + let mut need_inputs: HashSet<_, RandomState> = HashSet::from_iter( + (0..context.signature(root).unwrap_or_default().input_count()).map(IncomingPort::from), + ); self.0.iter().for_each(|(n, p, _)| { if n == &root { - if let Some(e) = have_inputs.get_mut(p.index()) { - *e = true; - } + need_inputs.remove(p); } }); - for (i, b) in have_inputs.into_iter().enumerate() { - if !b { - self.0 - .push((root, IncomingPort::from(i), PartialValue::Top)); - } + for p in need_inputs { + self.0.push((root, p, PartialValue::Top)); } // Note/TODO, if analysis is running on a subregion then we should do similar // for any nonlocal edges providing values from outside the region. @@ -109,6 +108,7 @@ pub(super) fn run_datalog>( // Initialize all wires to bottom out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p); + // Outputs to inputs in_wire_value(n, ip, v) <-- in_wire(n, ip), if let Some((m, op)) = ctx.single_linked_output(*n, *ip), out_wire_value(m, op, v); @@ -120,10 +120,11 @@ pub(super) fn run_datalog>( if let Some(sig) = ctx.signature(*n), if sig.input_ports().contains(p); - // Assemble in_value_row from in_value's + // Assemble node_in_value_row from in_wire_value's node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = ctx.signature(*n); node_in_value_row(n, ValueRow::single_known(ctx.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); + // Interpret leaf ops out_wire_value(n, p, v) <-- node(n), let op_t = ctx.get_optype(*n), @@ -133,7 +134,7 @@ pub(super) fn run_datalog>( if let Some(outs) = propagate_leaf_op(&ctx, *n, &vs[..], sig.output_count()), for (p, v) in (0..).map(OutgoingPort::from).zip(outs); - // DFG + // DFG -------------------- relation dfg_node(Node); // is a `DFG` dfg_node(n) <-- node(n), if ctx.get_optype(*n).is_dfg(); @@ -143,9 +144,7 @@ pub(super) fn run_datalog>( out_wire_value(dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), output_child(dfg, o), in_wire_value(o, p, v); - - // TailLoop - + // TailLoop -------------------- // inputs of tail loop propagate to Input node of child region out_wire_value(i, OutgoingPort::from(p.index()), v) <-- node(tl), if ctx.get_optype(*tl).is_tail_loop(), @@ -169,13 +168,11 @@ pub(super) fn run_datalog>( if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 for (out_p, v) in fields.enumerate(); - // Conditional - relation conditional_node(Node); // is a `Conditional` + // Conditional -------------------- // is a `Conditional` and its 'th child (a `Case`) is : relation case_node(Node, usize, Node); - - conditional_node(n)<-- node(n), if ctx.get_optype(*n).is_conditional(); - case_node(cond, i, case) <-- conditional_node(cond), + case_node(cond, i, case) <-- node(cond), + if ctx.get_optype(*cond).is_conditional(), for (i, case) in ctx.children(*cond).enumerate(), if ctx.get_optype(case).is_case(); @@ -195,17 +192,17 @@ pub(super) fn run_datalog>( output_child(case, o), in_wire_value(o, o_p, v); - // In `Conditional` , child `Case` is reachable given our knowledge of predicate + // In `Conditional` , child `Case` is reachable given our knowledge of predicate: relation case_reachable(Node, Node); case_reachable(cond, case) <-- case_node(cond, i, case), in_wire_value(cond, IncomingPort::from(0), v), if v.supports_tag(*i); - // CFG + // CFG -------------------- relation cfg_node(Node); // is a `CFG` cfg_node(n) <-- node(n), if ctx.get_optype(*n).is_cfg(); - // In `CFG` , basic block is reachable given our knowledge of predicates + // In `CFG` , basic block is reachable given our knowledge of predicates: relation bb_reachable(Node, Node); bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = ctx.children(*cfg).next(); bb_reachable(cfg, bb) <-- cfg_node(cfg), @@ -223,7 +220,7 @@ pub(super) fn run_datalog>( in_wire_value(cfg, p, v); // In `CFG` , values fed along a control-flow edge to - // come out of Value outports of . + // come out of Value outports of : relation _cfg_succ_dest(Node, Node, Node); _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = ctx.children(*cfg).nth(1); _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), @@ -242,7 +239,7 @@ pub(super) fn run_datalog>( if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), for (out_p, v) in fields.enumerate(); - // Call + // Call -------------------- relation func_call(Node, Node); // is a `Call` to `FuncDefn` func_call(call, func_defn) <-- node(call), @@ -327,7 +324,7 @@ fn propagate_leaf_op( )) } OpType::ExtensionOp(e) => { - // Interpret op. + // Interpret op using DFContext let init = if ins.iter().contains(&PartialValue::Bottom) { // So far we think one or more inputs can't happen. // So, don't pollute outputs with Top, and wait for better knowledge of inputs. @@ -337,9 +334,8 @@ fn propagate_leaf_op( PartialValue::Top }; let mut outs = vec![init; num_outs]; - // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, - // thus keeping PartialValue hidden, but AbstractValues - // are not necessarily convertible to Value. + // It might be nice to convert these to [(IncomingPort, Value)], or some concrete value, + // for the context, but PV contains more information, and try_into_value may fail. ctx.interpret_leaf_op(n, e, ins, &mut outs[..]); Some(ValueRow::from_iter(outs)) } diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 12f1733d3..0086629a1 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -362,7 +362,6 @@ impl TryFrom> for Value { impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); - // println!("join {self:?}\n{:?}", &other); match (&*self, other) { (Self::Top, _) => false, (_, other @ Self::Top) => { diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 057242d03..01b1474dd 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -312,6 +312,7 @@ fn xor_and_cfg() -> Hugr { let mut builder = CFGBuilder::new(Signature::new(type_row![BOOL_T; 2], type_row![BOOL_T; 2])).unwrap(); let false_c = builder.add_constant(Value::false_val()); + // entry (x, y) => if x {A(y, x=true)} else B(y)} let entry_outs = [type_row![BOOL_T;2], type_row![BOOL_T]]; let mut entry = builder From 19ec62edabdd5ceb38f69e7bc291f2e058a3bc91 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 15:34:56 +0000 Subject: [PATCH 203/281] v1 test...fails: tuple can't be removed, nothing removed, but at least '9' added --- hugr-passes/src/const_fold/test.rs | 48 ++++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 8e594a8af..927a15c58 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -1,9 +1,12 @@ -use hugr_core::builder::{inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr}; +use std::collections::HashSet; + +use hugr_core::builder::{endo_sig, inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr}; use hugr_core::extension::prelude::{ - const_ok, sum_with_error, ConstError, ConstString, UnpackTuple, BOOL_T, ERROR_TYPE, STRING_TYPE, + const_ok, sum_with_error, ConstError, ConstString, MakeTuple, UnpackTuple, BOOL_T, ERROR_TYPE, STRING_TYPE }; use hugr_core::extension::{ExtensionRegistry, PRELUDE}; use hugr_core::hugr::hugrmut::HugrMut; +use hugr_core::ops::{OpTag, OpTrait}; use hugr_core::ops::{constant::CustomConst, OpType, Value}; use hugr_core::std_extensions::arithmetic::{ self, @@ -1557,3 +1560,44 @@ fn test_fold_int_ops() { let expected = Value::true_val(); assert_fully_folded(&h, &expected); } + +#[test] +fn test_via_tuple() { + // fn(x) -> let (a,b,c) = (4,5,x) // make tuple, unpack tuple + // in (a+b)+c + let mut builder = DFGBuilder::new(endo_sig(INT_TYPES[3].clone())).unwrap(); + let [x] = builder.input_wires_arr(); + let cst4 = builder.add_load_value(ConstInt::new_u(3, 4).unwrap()); + let cst5 = builder.add_load_value(ConstInt::new_u(3, 5).unwrap()); + let tuple_ty = TypeRow::from(vec![INT_TYPES[3].clone(); 3]); + let tup = builder.add_dataflow_op(MakeTuple::new(tuple_ty.clone()), [cst4, cst5, x]).unwrap(); + let untup = builder.add_dataflow_op(UnpackTuple::new(tuple_ty), tup.outputs()).unwrap(); + let [a,b,c] = untup.outputs_arr(); + let add_ab = builder.add_dataflow_op(IntOpDef::iadd.with_log_width(3), [a,b]).unwrap(); + let [ab] = add_ab.outputs_arr(); + let res = builder.add_dataflow_op(IntOpDef::iadd.with_log_width(3), [ab,c]).unwrap(); + let reg = ExtensionRegistry::try_new([arithmetic::int_types::EXTENSION.to_owned()]).unwrap(); + let mut hugr = builder.finish_hugr_with_outputs(res.outputs(), ®).unwrap(); + + constant_fold_pass(&mut hugr, ®); + println!("{}", hugr.mermaid_string()); + + // We expect: root dfg, input, output, const 9, load constant, iadd, MAKETUPLE, UNPACKTUPLE + let mut expected_op_tags = Vec::from_iter([OpTag::Dfg, OpTag::Input, OpTag::Output, OpTag::Const, OpTag::LoadConst]); + let mut expected_opaque_names: HashSet<_, std::hash::RandomState> = HashSet::from_iter(["MakeTuple", "UnpackTuple", "iadd"]); + for n in hugr.nodes() { + let t = hugr.get_optype(n); + if let Some(e) = t.as_extension_op() { + let removed = expected_opaque_names.remove(e.def().name().as_str()); + assert!(removed); + } else { + let Some((idx, _)) = expected_op_tags.iter().enumerate().find(|(_,v)| **v == t.tag()) else { + panic!("Did not expect {:?}", t); + }; + expected_op_tags.remove(idx); + if let Some(c) = t.as_const() { + assert_eq!(c.value, ConstInt::new_u(3, 9).unwrap().into()) + } + } + } +} \ No newline at end of file From 72a3034fd2525b727f9ba21742dad7c8cf8a9a81 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 15:40:55 +0000 Subject: [PATCH 204/281] v2, discard unknown in middle --- hugr-passes/src/const_fold/test.rs | 66 +++++++++++++++++------------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 927a15c58..8a6d0092b 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -2,12 +2,13 @@ use std::collections::HashSet; use hugr_core::builder::{endo_sig, inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr}; use hugr_core::extension::prelude::{ - const_ok, sum_with_error, ConstError, ConstString, MakeTuple, UnpackTuple, BOOL_T, ERROR_TYPE, STRING_TYPE + const_ok, sum_with_error, ConstError, ConstString, MakeTuple, UnpackTuple, BOOL_T, ERROR_TYPE, + STRING_TYPE, }; use hugr_core::extension::{ExtensionRegistry, PRELUDE}; use hugr_core::hugr::hugrmut::HugrMut; -use hugr_core::ops::{OpTag, OpTrait}; use hugr_core::ops::{constant::CustomConst, OpType, Value}; +use hugr_core::ops::{OpTag, OpTrait}; use hugr_core::std_extensions::arithmetic::{ self, conversions::ConvertOpDef, @@ -1562,42 +1563,49 @@ fn test_fold_int_ops() { } #[test] -fn test_via_tuple() { - // fn(x) -> let (a,b,c) = (4,5,x) // make tuple, unpack tuple - // in (a+b)+c +fn test_via_part_unknown_tuple() { + // fn(x) -> let (a,_b,c) = (4,x,5) // make tuple, unpack tuple + // in a+b let mut builder = DFGBuilder::new(endo_sig(INT_TYPES[3].clone())).unwrap(); let [x] = builder.input_wires_arr(); let cst4 = builder.add_load_value(ConstInt::new_u(3, 4).unwrap()); let cst5 = builder.add_load_value(ConstInt::new_u(3, 5).unwrap()); let tuple_ty = TypeRow::from(vec![INT_TYPES[3].clone(); 3]); - let tup = builder.add_dataflow_op(MakeTuple::new(tuple_ty.clone()), [cst4, cst5, x]).unwrap(); - let untup = builder.add_dataflow_op(UnpackTuple::new(tuple_ty), tup.outputs()).unwrap(); - let [a,b,c] = untup.outputs_arr(); - let add_ab = builder.add_dataflow_op(IntOpDef::iadd.with_log_width(3), [a,b]).unwrap(); - let [ab] = add_ab.outputs_arr(); - let res = builder.add_dataflow_op(IntOpDef::iadd.with_log_width(3), [ab,c]).unwrap(); + let tup = builder + .add_dataflow_op(MakeTuple::new(tuple_ty.clone()), [cst4, x, cst5]) + .unwrap(); + let untup = builder + .add_dataflow_op(UnpackTuple::new(tuple_ty), tup.outputs()) + .unwrap(); + let [a, _b, c] = untup.outputs_arr(); + let res = builder + .add_dataflow_op(IntOpDef::iadd.with_log_width(3), [a, c]) + .unwrap(); let reg = ExtensionRegistry::try_new([arithmetic::int_types::EXTENSION.to_owned()]).unwrap(); - let mut hugr = builder.finish_hugr_with_outputs(res.outputs(), ®).unwrap(); - + let mut hugr = builder + .finish_hugr_with_outputs(res.outputs(), ®) + .unwrap(); + constant_fold_pass(&mut hugr, ®); - println!("{}", hugr.mermaid_string()); - // We expect: root dfg, input, output, const 9, load constant, iadd, MAKETUPLE, UNPACKTUPLE - let mut expected_op_tags = Vec::from_iter([OpTag::Dfg, OpTag::Input, OpTag::Output, OpTag::Const, OpTag::LoadConst]); - let mut expected_opaque_names: HashSet<_, std::hash::RandomState> = HashSet::from_iter(["MakeTuple", "UnpackTuple", "iadd"]); + // We expect: root dfg, input, output, const 9, load constant, iadd + let mut expected_op_tags: HashSet<_, std::hash::RandomState> = [ + OpTag::Dfg, + OpTag::Input, + OpTag::Output, + OpTag::Const, + OpTag::LoadConst, + ] + .map(|t| t.to_string()) + .into_iter() + .collect(); for n in hugr.nodes() { let t = hugr.get_optype(n); - if let Some(e) = t.as_extension_op() { - let removed = expected_opaque_names.remove(e.def().name().as_str()); - assert!(removed); - } else { - let Some((idx, _)) = expected_op_tags.iter().enumerate().find(|(_,v)| **v == t.tag()) else { - panic!("Did not expect {:?}", t); - }; - expected_op_tags.remove(idx); - if let Some(c) = t.as_const() { - assert_eq!(c.value, ConstInt::new_u(3, 9).unwrap().into()) - } + let removed = expected_op_tags.remove(&t.tag().to_string()); + assert!(removed); + if let Some(c) = t.as_const() { + assert_eq!(c.value, ConstInt::new_u(3, 9).unwrap().into()) } } -} \ No newline at end of file + assert!(expected_op_tags.is_empty()); +} From a339cb0eb269f9ec57a6da53cf2e4765718a999a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 17:30:45 +0000 Subject: [PATCH 205/281] docs --- hugr-passes/src/const_fold.rs | 4 +++- hugr-passes/src/const_fold/value_handle.rs | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index e7eb382d5..ff558bf1f 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -1,7 +1,7 @@ +#![warn(missing_docs)] //! An (example) use of the [super::dataflow](dataflow-analysis framework) //! to perform constant-folding. -// These are pub because this "example" is used for testing the framework. pub mod value_handle; use std::collections::{HashSet, VecDeque}; @@ -36,6 +36,7 @@ pub struct ConstFoldPass { } impl ConstFoldPass { + /// Sets the validation level used before and after the pass is run pub fn validation_level(mut self, level: ValidationLevel) -> Self { self.validation = level; self @@ -89,6 +90,7 @@ impl ConstFoldPass { Ok(()) } + /// Run the pass using this configuration pub fn run( &self, hugr: &mut H, diff --git a/hugr-passes/src/const_fold/value_handle.rs b/hugr-passes/src/const_fold/value_handle.rs index eb92aa018..6dffd54b6 100644 --- a/hugr-passes/src/const_fold/value_handle.rs +++ b/hugr-passes/src/const_fold/value_handle.rs @@ -1,3 +1,5 @@ +//! Total equality (and hence [AbstractValue] support for [Value]s +//! (by adding a source-Node and part unhashable constants) use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76. use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -9,6 +11,7 @@ use itertools::Either; use crate::dataflow::AbstractValue; +/// A custom constant that has been successfully hashed via [TryHash](hugr_core::ops::constant::TryHash) #[derive(Clone, Debug)] pub struct HashedConst { hash: u64, @@ -39,9 +42,15 @@ impl Hash for HashedConst { } } +/// A [Node] (expected to be a [Const]) and, for Sum constants, optionally, +/// indices of elements (nested arbitrarily deeply) within that. +/// +/// [Const]: hugr_core::ops::Const #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum NodePart { + /// The specified-index'th field of the [Sum](Value::Sum) constant identified by the RHS Field(usize, Box), + /// The entire value produced by the node Node(Node), } @@ -53,13 +62,18 @@ impl NodePart { } } +/// An [Eq]-able and [Hash]-able leaf (non-[Sum](Value::Sum)) Value #[derive(Clone, Debug)] pub enum ValueHandle { + /// A [Value::Extension] that has been hashed Hashable(HashedConst), + /// Either a [Value::Extension] that can't be hashed, or a [Value::Function]. Unhashable(NodePart, Either, Arc>), } impl ValueHandle { + /// Makes a new instance from an [OpaqueValue] given the node and (for a [Sum](Value::Sum)) + /// field indices within that (used only if the custom constant is not hashable). pub fn new_opaque(node: Node, fields: &[usize], val: OpaqueValue) -> Self { let arc = Arc::new(val); HashedConst::try_new(arc.clone()).map_or( @@ -68,6 +82,7 @@ impl ValueHandle { ) } + /// New instance for a [Value::Function] found within a node pub fn new_const_hugr(node: Node, fields: &[usize], val: Box) -> Self { Self::Unhashable(NodePart::new(node, fields), Either::Right(Arc::from(val))) } From b871f6173f92d1b965cb929626fca234efd8cd76 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 17:57:05 +0000 Subject: [PATCH 206/281] cfg --- hugr-passes/src/const_fold.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index ff558bf1f..557f825fe 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -115,12 +115,8 @@ impl ConstFoldPass { }; if h.get_optype(n).is_cfg() { for bb in h.children(n) { - if results.bb_reachable(bb).unwrap() - && needed.insert(bb) - && h.get_optype(bb).is_dataflow_block() - { - q.push_back(bb); - } + //if results.bb_reachable(bb).unwrap() { // no, we'd need to patch up predicates + q.push_back(bb); } } else if let Some(inout) = h.get_io(n) { // Dataflow. Find minimal nodes necessary to compute output, including StateOrder edges. From 2c615c6731aa98bb371c3512821bc45347e8e1fa Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 18:00:09 +0000 Subject: [PATCH 207/281] drop unused From for Box --- hugr-core/src/ops/constant.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index b267ba777..5b9309407 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -301,12 +301,6 @@ impl From for OpaqueValue { } } -impl From for Box { - fn from(value: OpaqueValue) -> Self { - value.v - } -} - impl PartialEq for OpaqueValue { fn eq(&self, other: &Self) -> bool { self.value().equal_consts(other.value()) From 5c880e4dfbc86dee4826c15833a7e3196fd0c6ae Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 31 Oct 2024 14:41:17 +0000 Subject: [PATCH 208/281] ...and patch up try_into_wire_value's with type annotations --- hugr-passes/src/const_fold.rs | 19 ++++++++++--------- hugr-passes/src/const_fold/test.rs | 7 ++----- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 557f825fe..c05a73083 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -7,13 +7,10 @@ use std::collections::{HashSet, VecDeque}; use hugr_core::{ extension::ExtensionRegistry, - hugr::{ - hugrmut::HugrMut, - views::{DescendantsGraph, ExtractHugr, HierarchyView}, - }, - ops::{ - constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant, - }, + hugr::hugrmut::HugrMut, + hugr::views::{DescendantsGraph, ExtractHugr, HierarchyView}, + ops::constant::OpaqueValue, + ops::{handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant, Value}, types::{EdgeKind, TypeArg}, HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire, }; @@ -69,7 +66,9 @@ impl ConstFoldPass { (!results.hugr().get_optype(src).is_load_constant()).then_some(( n, ip, - results.try_read_wire_value(Wire::new(src, outp)).ok()?, + results + .try_read_wire_value::(Wire::new(src, outp)) + .ok()?, )) }) .collect::>(); @@ -139,7 +138,9 @@ impl ConstFoldPass { let needs_predecessor = match h.get_optype(src).port_kind(op).unwrap() { EdgeKind::Value(_) => { results.hugr().get_optype(src).is_load_constant() - || results.try_read_wire_value(Wire::new(src, op)).is_err() + || results + .try_read_wire_value::(Wire::new(src, op)) + .is_err() } EdgeKind::StateOrder | EdgeKind::Const(_) | EdgeKind::Function(_) => true, EdgeKind::ControlFlow => panic!(), diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 8a6d0092b..fe78944e5 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -103,11 +103,8 @@ fn f2c(f: f64) -> Value { // c = a + b fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { fn unwrap_float(pv: PartialValue) -> f64 { - pv.try_into_value::(&FLOAT64_TYPE) - .unwrap() - .get_custom_value::() - .unwrap() - .value() + let v: Value = pv.try_into_value(&FLOAT64_TYPE).unwrap(); + v.get_custom_value::().unwrap().value() } let [n, n_a, n_b] = [0, 1, 2].map(portgraph::NodeIndex::new).map(Node::from); let mut temp = Hugr::default(); From fb3816e4c6e14c2a2e8889b0a6695f1214969600 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 4 Nov 2024 12:41:49 +0000 Subject: [PATCH 209/281] Massively simplify xor_and_cfg, no need for conditionals --- hugr-passes/src/dataflow/test.rs | 94 ++++++++++---------------------- 1 file changed, 30 insertions(+), 64 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 01b1474dd..c8961ea2c 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -304,90 +304,56 @@ fn test_conditional() { // A Hugr being a function on bools: (x, y) => (x XOR y, x AND y) #[fixture] fn xor_and_cfg() -> Hugr { - // Entry - // /0 1\ - // A --1-> B A(x=true, y) => if y then X(false, true) else B(x=true) - // \0 / B(z) => X(z,false) + // Entry branch on first arg, passes arguments on unchanged + // /T F\ + // A --T-> B A(x=true, y) branch on second arg, passing (first arg == true, false) + // \F / B(w,v) => X(v,w) // > X < + // Inputs received: + // Entry A B X + // F,F - F,F F,F + // F,T - F,T T,F + // T,F T,F - T,F + // T,T T,T T,F F,T let mut builder = CFGBuilder::new(Signature::new(type_row![BOOL_T; 2], type_row![BOOL_T; 2])).unwrap(); - let false_c = builder.add_constant(Value::false_val()); - // entry (x, y) => if x {A(y, x=true)} else B(y)} - let entry_outs = [type_row![BOOL_T;2], type_row![BOOL_T]]; - let mut entry = builder - .entry_builder(entry_outs.clone(), type_row![]) + // entry (x, y) => (if x then A else B)(x=true, y) + let entry = builder + .entry_builder(vec![type_row![]; 2], type_row![BOOL_T;2]) .unwrap(); let [in_x, in_y] = entry.input_wires_arr(); - let mut cond = entry - .conditional_builder( - (vec![type_row![]; 2], in_x), - [], - Type::new_sum(entry_outs.clone()).into(), - ) - .unwrap(); - let mut if_x_true = cond.case_builder(1).unwrap(); - let br_to_a = if_x_true - .add_dataflow_op(Tag::new(0, entry_outs.to_vec()), [in_y, in_x]) - .unwrap(); - if_x_true.finish_with_outputs(br_to_a.outputs()).unwrap(); - let mut if_x_false = cond.case_builder(0).unwrap(); - let br_to_b = if_x_false - .add_dataflow_op(Tag::new(1, entry_outs.into()), [in_y]) - .unwrap(); - if_x_false.finish_with_outputs(br_to_b.outputs()).unwrap(); - - let [res] = cond.finish_sub_container().unwrap().outputs_arr(); - let entry = entry.finish_with_outputs(res, []).unwrap(); + let entry = entry.finish_with_outputs(in_x, [in_x, in_y]).unwrap(); - // A(y, z always true) => if y {X(false, z)} else {B(z)} - let a_outs = vec![type_row![BOOL_T], type_row![]]; + // A(x==true, y) => (if y then B else X)(x, false) let mut a = builder .block_builder( type_row![BOOL_T; 2], - a_outs.clone(), - type_row![BOOL_T], // Trailing z common to both branches + vec![type_row![]; 2], + type_row![BOOL_T; 2], ) .unwrap(); - let [in_y, in_z] = a.input_wires_arr(); + let [in_x, in_y] = a.input_wires_arr(); + let false_w1 = a.add_load_value(Value::false_val()); + let a = a.finish_with_outputs(in_y, [in_x, false_w1]).unwrap(); - let mut cond = a - .conditional_builder( - (vec![type_row![]; 2], in_y), - [], - Type::new_sum(a_outs.clone()).into(), - ) - .unwrap(); - let mut if_y_true = cond.case_builder(1).unwrap(); - let false_w1 = if_y_true.load_const(&false_c); - let br_to_x = if_y_true - .add_dataflow_op(Tag::new(0, a_outs.clone()), [false_w1]) - .unwrap(); - if_y_true.finish_with_outputs(br_to_x.outputs()).unwrap(); - let mut if_y_false = cond.case_builder(0).unwrap(); - let br_to_b = if_y_false.add_dataflow_op(Tag::new(1, a_outs), []).unwrap(); - if_y_false.finish_with_outputs(br_to_b.outputs()).unwrap(); - let [res] = cond.finish_sub_container().unwrap().outputs_arr(); - let a = a.finish_with_outputs(res, [in_z]).unwrap(); - - // B(v) => X(v, false) + // B(w, v) => X(v, w) let mut b = builder - .block_builder(type_row![BOOL_T], [type_row![]], type_row![BOOL_T; 2]) + .block_builder(type_row![BOOL_T; 2], [type_row![]], type_row![BOOL_T; 2]) .unwrap(); - let [in_v] = b.input_wires_arr(); - let false_w2 = b.load_const(&false_c); + let [in_w, in_v] = b.input_wires_arr(); let [control] = b .add_dataflow_op(Tag::new(0, vec![type_row![]]), []) .unwrap() .outputs_arr(); - let b = b.finish_with_outputs(control, [in_v, false_w2]).unwrap(); + let b = b.finish_with_outputs(control, [in_v, in_w]).unwrap(); let x = builder.exit_block(); - builder.branch(&entry, 0, &a).unwrap(); - builder.branch(&entry, 1, &b).unwrap(); - builder.branch(&a, 0, &x).unwrap(); - builder.branch(&a, 1, &b).unwrap(); + builder.branch(&entry, 1, &a).unwrap(); // if true + builder.branch(&entry, 0, &b).unwrap(); // if false + builder.branch(&a, 1, &b).unwrap(); // if true + builder.branch(&a, 0, &x).unwrap(); // if false builder.branch(&b, 0, &x).unwrap(); builder.finish_hugr(&EMPTY_REG).unwrap() } @@ -402,9 +368,9 @@ fn xor_and_cfg() -> Hugr { #[case(pv_false(), pv_true_or_false(), pv_true_or_false(), pv_false())] #[case(pv_false(), PartialValue::Top, PartialValue::Top, pv_false())] // if !inp0 then out0=inp1 #[case(pv_true_or_false(), pv_true(), pv_true_or_false(), pv_true_or_false())] -#[case(pv_true_or_false(), pv_false(), pv_true_or_false(), pv_false())] +#[case(pv_true_or_false(), pv_false(), pv_true_or_false(), pv_true_or_false())] #[case(PartialValue::Top, pv_true(), pv_true_or_false(), PartialValue::Top)] -#[case(PartialValue::Top, pv_false(), PartialValue::Top, pv_false())] +#[case(PartialValue::Top, pv_false(), PartialValue::Top, PartialValue::Top)] fn test_cfg( #[case] inp0: PartialValue, #[case] inp1: PartialValue, From 19571f6a9d2bcd94c1afe49b9b136f7505a4d493 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 4 Nov 2024 13:12:19 +0000 Subject: [PATCH 210/281] Use tru/fals constants --- hugr-passes/src/dataflow/test.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index c8961ea2c..a300965b1 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -350,10 +350,11 @@ fn xor_and_cfg() -> Hugr { let x = builder.exit_block(); - builder.branch(&entry, 1, &a).unwrap(); // if true - builder.branch(&entry, 0, &b).unwrap(); // if false - builder.branch(&a, 1, &b).unwrap(); // if true - builder.branch(&a, 0, &x).unwrap(); // if false + let [fals, tru]: [usize; 2] = [0, 1]; + builder.branch(&entry, tru, &a).unwrap(); // if true + builder.branch(&entry, fals, &b).unwrap(); // if false + builder.branch(&a, tru, &b).unwrap(); // if true + builder.branch(&a, fals, &x).unwrap(); // if false builder.branch(&b, 0, &x).unwrap(); builder.finish_hugr(&EMPTY_REG).unwrap() } From ec2cc78c280cdd9a7ba52071c301327f656ecc40 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 4 Nov 2024 11:35:15 +0000 Subject: [PATCH 211/281] Add recursive might_diverge, assume true for all CFGs --- hugr-passes/src/const_fold.rs | 50 ++++++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index c05a73083..618ddf45b 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -27,9 +27,10 @@ use crate::{ /// A configuration for the Constant Folding pass. pub struct ConstFoldPass { validation: ValidationLevel, - /// If true, allow to skip evaluating loops (whose results are not needed) even if - /// we are not sure they will terminate. (If they definitely terminate then fair game.) - pub allow_skip_loops: bool, + /// If true, allow to skip evaluating [TailLoop]s and [CFGs] (whose results are known, + /// or not needed) even if we are not sure they will terminate. That is, allow + /// transforming a potentially non-terminating graph into a definitely-terminating one. + pub allow_increase_termination: bool, } impl ConstFoldPass { @@ -112,6 +113,7 @@ impl ConstFoldPass { if !needed.insert(n) { continue; }; + if h.get_optype(n).is_cfg() { for bb in h.children(n) { //if results.bb_reachable(bb).unwrap() { // no, we'd need to patch up predicates @@ -121,15 +123,13 @@ impl ConstFoldPass { // Dataflow. Find minimal nodes necessary to compute output, including StateOrder edges. q.extend(inout); // Input also necessary for legality even if unreachable - // Also add on anything that might not terminate. We might also allow a custom predicate for extension ops? - for ch in h.children(n) { - if h.get_optype(ch).is_cfg() - || (!self.allow_skip_loops - && h.get_optype(ch).is_tail_loop() - && results.tail_loop_terminates(ch).unwrap() - != TailLoopTermination::NeverContinues) - { - q.push_back(ch); + if !self.allow_increase_termination { + // Also add on anything that might not terminate (even if results not required - + // if its results are required we'll add it by following dataflow, below.) + for ch in h.children(n) { + if self.might_diverge(results, ch) { + q.push_back(ch); + } } } } @@ -152,6 +152,32 @@ impl ConstFoldPass { } } } + + // "Diverge" aka "never-terminate" + // TODO would be more efficient to compute this bottom-up and cache (dynamic programming) + fn might_diverge( + &self, + results: &AnalysisResults>, + n: Node, + ) -> bool { + let op = results.hugr().get_optype(n); + if op.is_cfg() { + // TODO if the CFG has no cycles (that are possible given predicates) + // then we could say it definitely terminates (i.e. return false) + true + } else if op.is_tail_loop() + && results.tail_loop_terminates(n).unwrap() != TailLoopTermination::NeverContinues + { + // If we can even figure out the number of iterations is bounded that would allow returning false. + true + } else { + // Node does not introduce non-termination, but still non-terminates if any of its children does + results + .hugr() + .children(n) + .any(|ch| self.might_diverge(results, ch)) + } + } } /// Exhaustively apply constant folding to a HUGR. From 6ba0b25557594d82fb4b4aed5b6af8dca87558e7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 4 Nov 2024 11:35:38 +0000 Subject: [PATCH 212/281] Clarify "non-exhaustive" in comment --- hugr-passes/src/const_fold.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 618ddf45b..382888451 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -144,7 +144,7 @@ impl ConstFoldPass { } EdgeKind::StateOrder | EdgeKind::Const(_) | EdgeKind::Function(_) => true, EdgeKind::ControlFlow => panic!(), - _ => true, // needed for non-exhaustive; not knowing what it is, assume the worst + _ => true, // needed as EdgeKind non-exhaustive; not knowing what it is, assume the worst }; if needs_predecessor { q.push_back(src); From e681602973bbd64a6ed81290e65eacc008657448 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 4 Nov 2024 16:59:25 +0000 Subject: [PATCH 213/281] ConstFoldContext needs only & not &mut --- hugr-passes/src/const_fold.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 382888451..fb8d68a14 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -44,7 +44,7 @@ impl ConstFoldPass { fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { let results = Machine::default().run(ConstFoldContext(hugr), []); let mut keep_nodes = HashSet::new(); - self.find_needed_nodes(&results, results.hugr().root(), &mut keep_nodes); + self.find_needed_nodes(&results, hugr.root(), &mut keep_nodes); let remove_nodes = results .hugr() @@ -53,18 +53,18 @@ impl ConstFoldPass { .collect::>(); let wires_to_break = keep_nodes .into_iter() - .flat_map(|n| results.hugr().node_inputs(n).map(move |ip| (n, ip))) + .flat_map(|n| hugr.node_inputs(n).map(move |ip| (n, ip))) .filter(|(n, ip)| { matches!( - results.hugr().get_optype(*n).port_kind(*ip).unwrap(), + hugr.get_optype(*n).port_kind(*ip).unwrap(), EdgeKind::Value(_) ) }) // Note we COULD filter out (avoid breaking) wires from other nodes that we are keeping. // This would insert fewer constants, but potentially expose less parallelism. .filter_map(|(n, ip)| { - let (src, outp) = results.hugr().single_linked_output(n, ip).unwrap(); - (!results.hugr().get_optype(src).is_load_constant()).then_some(( + let (src, outp) = hugr.single_linked_output(n, ip).unwrap(); + (!hugr.get_optype(src).is_load_constant()).then_some(( n, ip, results @@ -137,7 +137,7 @@ impl ConstFoldPass { for (src, op) in h.all_linked_outputs(n) { let needs_predecessor = match h.get_optype(src).port_kind(op).unwrap() { EdgeKind::Value(_) => { - results.hugr().get_optype(src).is_load_constant() + h.get_optype(src).is_load_constant() || results .try_read_wire_value::(Wire::new(src, op)) .is_err() @@ -185,7 +185,7 @@ pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { ConstFoldPass::default().run(h, reg).unwrap() } -struct ConstFoldContext<'a, H>(&'a mut H); +struct ConstFoldContext<'a, H>(&'a H); impl<'a, H: HugrView> std::ops::Deref for ConstFoldContext<'a, H> { type Target = H; From fc44ded02e8f0da574baf16f6e2575ed668f86cb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 4 Nov 2024 18:05:22 +0000 Subject: [PATCH 214/281] Test constant folding a TailLoop (cannot remove loop) --- hugr-passes/src/const_fold/test.rs | 103 +++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index fe78944e5..a41b38a57 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -1606,3 +1606,106 @@ fn test_via_part_unknown_tuple() { } assert!(expected_op_tags.is_empty()); } + +#[test] +fn test_tail_loop() { + let reg = ExtensionRegistry::try_new([arithmetic::int_types::EXTENSION.to_owned()]).unwrap(); + let cst5 = ConstInt::new_u(3, 5).unwrap(); + let h = { + let mut builder = DFGBuilder::new(inout_sig(BOOL_T, INT_TYPES[3].clone())).unwrap(); + let [bool_w] = builder.input_wires_arr(); + let cst5 = builder.add_load_value(cst5.clone()); + let tlb = builder + .tail_loop_builder([], [(INT_TYPES[3].clone(), cst5)], type_row![]) + .unwrap(); + let [i] = tlb.input_wires_arr(); + // Loop either always breaks, or always iterates, depending on the boolean input + let [loop_out_w] = tlb.finish_with_outputs(bool_w, [i]).unwrap().outputs_arr(); + // The output of the loop is the constant, if the loop terminates + let add = builder + .add_dataflow_op(IntOpDef::iadd.with_log_width(3), [cst5, loop_out_w]) + .unwrap(); + + builder + .finish_hugr_with_outputs(add.outputs(), ®) + .unwrap() + }; + let mut h2 = h.clone(); + constant_fold_pass(&mut h2, ®); + assert_eq!(h2.node_count(), 12); + let tl = h2 + .nodes() + .filter(|n| h.get_optype(*n).is_tail_loop()) + .exactly_one() + .ok() + .unwrap(); + let mut dfg_nodes = Vec::new(); + let mut loop_nodes = Vec::new(); + for n in h2.nodes() { + if let Some(p) = h2.get_parent(n) { + if p == h2.root() { + dfg_nodes.push(n) + } else { + assert_eq!(p, tl); + loop_nodes.push(n); + } + } + } + let tag_string = |n: &Node| format!("{:?}", h2.get_optype(*n).tag()); + assert_eq!( + dfg_nodes + .iter() + .map(tag_string) + .sorted() + .collect::>(), + Vec::from([ + "Const", + "Const", + "Input", + "LoadConst", + "LoadConst", + "Output", + "TailLoop" + ]) + ); + + assert_eq!( + loop_nodes.iter().map(tag_string).collect::>(), + Vec::from(["Input", "Output", "Const", "LoadConst"]) + ); + + // In the loop, we have a new constant 5 instead of using the loop input + let [loop_in, loop_out] = h2.get_io(tl).unwrap(); + let (loop_cst, v) = loop_nodes + .into_iter() + .filter_map(|n| h2.get_optype(n).as_const().map(|c| (n, c.value()))) + .exactly_one() + .unwrap(); + assert_eq!(v, &cst5.clone().into()); + let loop_lcst = h2.output_neighbours(loop_cst).exactly_one().unwrap(); + assert_eq!(h2.get_parent(loop_lcst), Some(tl)); + assert_eq!( + h2.all_linked_inputs(loop_lcst).collect::>(), + vec![(loop_out, IncomingPort::from(1))] + ); + assert!(h2.input_neighbours(loop_in).next().is_none()); + + // Outer DFG contains two constants (we know) - a 5, used by the loop, and a 10, output. + let [_, root_out] = h2.get_io(h2.root()).unwrap(); + let mut cst5 = Some(cst5.into()); + for n in dfg_nodes { + let Some(cst) = h2.get_optype(n).as_const() else { + continue; + }; + let lcst = h2.output_neighbours(n).exactly_one().unwrap(); + let target = h2.output_neighbours(lcst).exactly_one().unwrap(); + if Some(cst.value()) == cst5.as_ref() { + cst5 = None; + assert_eq!(target, tl); + } else { + assert_eq!(cst.value(), &ConstInt::new_u(3, 10).unwrap().into()); + assert_eq!(target, root_out) + } + } + assert!(cst5.is_none()); // Found in loop +} From ae9cb7cc588644cbe5716b2e9f6a779dec8cdab5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 4 Nov 2024 18:10:45 +0000 Subject: [PATCH 215/281] Add pub fn allow_increase_termination, test allows removing tail-loop --- hugr-passes/src/const_fold.rs | 7 +++++++ hugr-passes/src/const_fold/test.rs | 9 ++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index fb8d68a14..cf8a6fc02 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -40,6 +40,13 @@ impl ConstFoldPass { self } + /// Allows the pass to remove potentially-non-terminating [TailLoop]s and [CFG] if their + /// result (if/when they do terminate) is either known or not needed + pub fn allow_increase_termination(mut self) -> Self { + self.allow_increase_termination = true; + self + } + /// Run the Constant Folding pass. fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { let results = Machine::default().run(ConstFoldContext(hugr), []); diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index a41b38a57..f70ac8806 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -25,7 +25,7 @@ use itertools::Itertools; use lazy_static::lazy_static; use rstest::rstest; -use super::{constant_fold_pass, ConstFoldContext, ValueHandle}; +use super::{constant_fold_pass, ConstFoldContext, ConstFoldPass, ValueHandle}; use crate::dataflow::{ConstLoader, DFContext, PartialValue}; #[rstest] @@ -1708,4 +1708,11 @@ fn test_tail_loop() { } } assert!(cst5.is_none()); // Found in loop + + let mut h3 = h.clone(); + ConstFoldPass::default() + .allow_increase_termination() + .run(&mut h3, ®) + .unwrap(); + assert_fully_folded(&h3, &ConstInt::new_u(3, 10).unwrap().into()); } From 3a43ae1da1ca12e037d5ab1a577df4afc2e7f3f2 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 4 Nov 2024 21:25:32 +0000 Subject: [PATCH 216/281] Hide allow_increase_termination field now we have method --- hugr-passes/src/const_fold.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index cf8a6fc02..b5bb827d6 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -27,10 +27,7 @@ use crate::{ /// A configuration for the Constant Folding pass. pub struct ConstFoldPass { validation: ValidationLevel, - /// If true, allow to skip evaluating [TailLoop]s and [CFGs] (whose results are known, - /// or not needed) even if we are not sure they will terminate. That is, allow - /// transforming a potentially non-terminating graph into a definitely-terminating one. - pub allow_increase_termination: bool, + allow_increase_termination: bool, } impl ConstFoldPass { @@ -41,7 +38,10 @@ impl ConstFoldPass { } /// Allows the pass to remove potentially-non-terminating [TailLoop]s and [CFG] if their - /// result (if/when they do terminate) is either known or not needed + /// result (if/when they do terminate) is either known or not needed. + /// + /// [TailLoop]: hugr_core::ops::TailLoop + /// [CFG]: hugr_core::ops::CFG pub fn allow_increase_termination(mut self) -> Self { self.allow_increase_termination = true; self From f0e84acdc9560922a878d349eebf2de917c9c2c8 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 4 Nov 2024 21:28:37 +0000 Subject: [PATCH 217/281] Allow adding inputs (pub traverse_value), fix needing port-num to distinguish --- hugr-passes/src/const_fold.rs | 35 +++++++++++++++++++++++++---------- hugr-passes/src/dataflow.rs | 2 +- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index b5bb827d6..76ceb5a09 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -3,22 +3,27 @@ //! to perform constant-folding. pub mod value_handle; -use std::collections::{HashSet, VecDeque}; +use std::collections::{HashMap, HashSet, VecDeque}; use hugr_core::{ extension::ExtensionRegistry, - hugr::hugrmut::HugrMut, - hugr::views::{DescendantsGraph, ExtractHugr, HierarchyView}, - ops::constant::OpaqueValue, - ops::{handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant, Value}, + hugr::{ + hugrmut::HugrMut, + views::{DescendantsGraph, ExtractHugr, HierarchyView}, + }, + ops::{ + constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant, + Value, + }, types::{EdgeKind, TypeArg}, - HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire, + HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire, }; use value_handle::ValueHandle; use crate::{ dataflow::{ - AnalysisResults, ConstLoader, DFContext, Machine, PartialValue, TailLoopTermination, + traverse_value, AnalysisResults, ConstLoader, DFContext, Machine, PartialValue, + TailLoopTermination, }, validation::{ValidatePassError, ValidationLevel}, }; @@ -28,6 +33,7 @@ use crate::{ pub struct ConstFoldPass { validation: ValidationLevel, allow_increase_termination: bool, + inputs: HashMap, } impl ConstFoldPass { @@ -49,7 +55,17 @@ impl ConstFoldPass { /// Run the Constant Folding pass. fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { - let results = Machine::default().run(ConstFoldContext(hugr), []); + let fresh_node = Node::from(portgraph::NodeIndex::new( + hugr.nodes().max().map_or(0, |n| n.index() + 1), + )); + let inputs = self.inputs.iter().map(|(p, v)| { + ( + p.clone(), + traverse_value(&ConstFoldContext(hugr), fresh_node, &mut vec![p.index()], v), + ) + }); + + let results = Machine::default().run(ConstFoldContext(hugr), inputs); let mut keep_nodes = HashSet::new(); self.find_needed_nodes(&results, hugr.root(), &mut keep_nodes); @@ -261,8 +277,7 @@ impl<'a, H: HugrView> DFContext for ConstFoldContext<'a, H> { }) .collect::>(); for (p, v) in op.constant_fold(&known_ins).unwrap_or_default() { - // Hmmm, we should (at least) key the value also by p - outs[p.index()] = self.value_from_const(node, &v); + outs[p.index()] = traverse_value(self, node, &mut vec![p.index()], &v); } } } diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index dd3e6d2c0..f052b32b5 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -79,7 +79,7 @@ pub trait ConstLoader { } } -fn traverse_value( +pub fn traverse_value( s: &(impl ConstLoader + ?Sized), n: Node, fields: &mut Vec, From 02262d67573037d6dd1017f86eeb673be9ffd59d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 4 Nov 2024 21:35:11 +0000 Subject: [PATCH 218/281] traverse_value => partial_from_const, fix docs --- hugr-passes/src/const_fold.rs | 6 +++--- hugr-passes/src/dataflow.rs | 15 ++++++++------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 76ceb5a09..3594baedc 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -22,7 +22,7 @@ use value_handle::ValueHandle; use crate::{ dataflow::{ - traverse_value, AnalysisResults, ConstLoader, DFContext, Machine, PartialValue, + partial_from_const, AnalysisResults, ConstLoader, DFContext, Machine, PartialValue, TailLoopTermination, }, validation::{ValidatePassError, ValidationLevel}, @@ -61,7 +61,7 @@ impl ConstFoldPass { let inputs = self.inputs.iter().map(|(p, v)| { ( p.clone(), - traverse_value(&ConstFoldContext(hugr), fresh_node, &mut vec![p.index()], v), + partial_from_const(&ConstFoldContext(hugr), fresh_node, &mut vec![p.index()], v), ) }); @@ -277,7 +277,7 @@ impl<'a, H: HugrView> DFContext for ConstFoldContext<'a, H> { }) .collect::>(); for (p, v) in op.constant_fold(&known_ins).unwrap_or_default() { - outs[p.index()] = traverse_value(self, node, &mut vec![p.index()], &v); + outs[p.index()] = partial_from_const(self, node, &mut vec![p.index()], &v); } } } diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index f052b32b5..f86d3c283 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -47,12 +47,9 @@ pub trait DFContext: ConstLoader + std::ops::Deref { /// are "correct" but maximally conservative (minimally informative). pub trait ConstLoader { /// Produces a [PartialValue] from a constant. The default impl (expected - /// to be appropriate in most cases) traverses [Sum](Value::Sum) constants - /// to their leaves ([Value::Extension] and [Value::Function]), - /// converts these using [Self::value_from_opaque] and [Self::value_from_const_hugr], - /// and builds nested [PartialValue::new_variant] to represent the structure. + /// to be appropriate in most cases) uses [partial_from_const]. fn value_from_const(&self, n: Node, cst: &Value) -> PartialValue { - traverse_value(self, n, &mut Vec::new(), cst) + partial_from_const(self, n, &mut Vec::new(), cst) } /// Produces an abstract value from an [OpaqueValue], if possible. @@ -79,7 +76,11 @@ pub trait ConstLoader { } } -pub fn traverse_value( +/// Converts a constant [Value] by traversing [Sum](Value::Sum) constants +/// to their leaves ([Value::Extension] and [Value::Function]), +/// converting these using [ConstLoader::value_from_opaque] and [ConstLoader::value_from_const_hugr], +/// and building nested [PartialValue::new_variant]s to represent the structure. +pub fn partial_from_const( s: &(impl ConstLoader + ?Sized), n: Node, fields: &mut Vec, @@ -89,7 +90,7 @@ pub fn traverse_value( Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { let elems = values.iter().enumerate().map(|(idx, elem)| { fields.push(idx); - let r = traverse_value(s, n, fields, elem); + let r = partial_from_const(s, n, fields, elem); fields.pop(); r }); From 81b8179c0ec928faedf314ba4c806008c7851d81 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 4 Nov 2024 21:52:11 +0000 Subject: [PATCH 219/281] fixup! Allow adding inputs --- hugr-passes/src/const_fold.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 3594baedc..f8e7402af 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -53,6 +53,17 @@ impl ConstFoldPass { self } + /// Specifies any number of external inputs to provide to the Hugr (on root-node + /// in-ports). Each supercedes any previous value on the same in-port. + pub fn with_inputs( + mut self, + inputs: impl IntoIterator, Value)>, + ) -> Self { + self.inputs + .extend(inputs.into_iter().map(|(p, v)| (p.into(), v))); + self + } + /// Run the Constant Folding pass. fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { let fresh_node = Node::from(portgraph::NodeIndex::new( From a038608e4e3cbad26bdfebebd118f1becc6150ed Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 4 Nov 2024 21:52:59 +0000 Subject: [PATCH 220/281] Expand test by providing always-break input, refactor/split --- hugr-passes/src/const_fold/test.rs | 103 +++++++++++++++++------------ 1 file changed, 61 insertions(+), 42 deletions(-) diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index f70ac8806..b7dc6f4db 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -1607,33 +1607,39 @@ fn test_via_part_unknown_tuple() { assert!(expected_op_tags.is_empty()); } -#[test] -fn test_tail_loop() { +fn tail_loop_hugr(int_cst: ConstInt) -> (Hugr, ExtensionRegistry) { let reg = ExtensionRegistry::try_new([arithmetic::int_types::EXTENSION.to_owned()]).unwrap(); + let int_ty = int_cst.get_type(); + let lw = int_cst.log_width(); + let mut builder = DFGBuilder::new(inout_sig(BOOL_T, int_ty.clone())).unwrap(); + let [bool_w] = builder.input_wires_arr(); + let lcst = builder.add_load_value(int_cst); + let tlb = builder + .tail_loop_builder([], [(int_ty, lcst)], type_row![]) + .unwrap(); + let [i] = tlb.input_wires_arr(); + // Loop either always breaks, or always iterates, depending on the boolean input + let [loop_out_w] = tlb.finish_with_outputs(bool_w, [i]).unwrap().outputs_arr(); + // The output of the loop is the constant, if the loop terminates + let add = builder + .add_dataflow_op(IntOpDef::iadd.with_log_width(lw), [lcst, loop_out_w]) + .unwrap(); + + let hugr = builder + .finish_hugr_with_outputs(add.outputs(), ®) + .unwrap(); + (hugr, reg) +} + +#[test] +fn test_tail_loop_unknown() { let cst5 = ConstInt::new_u(3, 5).unwrap(); - let h = { - let mut builder = DFGBuilder::new(inout_sig(BOOL_T, INT_TYPES[3].clone())).unwrap(); - let [bool_w] = builder.input_wires_arr(); - let cst5 = builder.add_load_value(cst5.clone()); - let tlb = builder - .tail_loop_builder([], [(INT_TYPES[3].clone(), cst5)], type_row![]) - .unwrap(); - let [i] = tlb.input_wires_arr(); - // Loop either always breaks, or always iterates, depending on the boolean input - let [loop_out_w] = tlb.finish_with_outputs(bool_w, [i]).unwrap().outputs_arr(); - // The output of the loop is the constant, if the loop terminates - let add = builder - .add_dataflow_op(IntOpDef::iadd.with_log_width(3), [cst5, loop_out_w]) - .unwrap(); - - builder - .finish_hugr_with_outputs(add.outputs(), ®) - .unwrap() - }; - let mut h2 = h.clone(); - constant_fold_pass(&mut h2, ®); - assert_eq!(h2.node_count(), 12); - let tl = h2 + let (mut h, reg) = tail_loop_hugr(cst5.clone()); + + constant_fold_pass(&mut h, ®); + // Must keep the loop, even though we know the output, in case the output doesn't happen + assert_eq!(h.node_count(), 12); + let tl = h .nodes() .filter(|n| h.get_optype(*n).is_tail_loop()) .exactly_one() @@ -1641,9 +1647,9 @@ fn test_tail_loop() { .unwrap(); let mut dfg_nodes = Vec::new(); let mut loop_nodes = Vec::new(); - for n in h2.nodes() { - if let Some(p) = h2.get_parent(n) { - if p == h2.root() { + for n in h.nodes() { + if let Some(p) = h.get_parent(n) { + if p == h.root() { dfg_nodes.push(n) } else { assert_eq!(p, tl); @@ -1651,7 +1657,7 @@ fn test_tail_loop() { } } } - let tag_string = |n: &Node| format!("{:?}", h2.get_optype(*n).tag()); + let tag_string = |n: &Node| format!("{:?}", h.get_optype(*n).tag()); assert_eq!( dfg_nodes .iter() @@ -1675,30 +1681,30 @@ fn test_tail_loop() { ); // In the loop, we have a new constant 5 instead of using the loop input - let [loop_in, loop_out] = h2.get_io(tl).unwrap(); + let [loop_in, loop_out] = h.get_io(tl).unwrap(); + assert!(h.input_neighbours(loop_in).next().is_none()); let (loop_cst, v) = loop_nodes .into_iter() - .filter_map(|n| h2.get_optype(n).as_const().map(|c| (n, c.value()))) + .filter_map(|n| h.get_optype(n).as_const().map(|c| (n, c.value()))) .exactly_one() .unwrap(); assert_eq!(v, &cst5.clone().into()); - let loop_lcst = h2.output_neighbours(loop_cst).exactly_one().unwrap(); - assert_eq!(h2.get_parent(loop_lcst), Some(tl)); + let loop_lcst = h.output_neighbours(loop_cst).exactly_one().unwrap(); + assert_eq!(h.get_parent(loop_lcst), Some(tl)); assert_eq!( - h2.all_linked_inputs(loop_lcst).collect::>(), + h.all_linked_inputs(loop_lcst).collect::>(), vec![(loop_out, IncomingPort::from(1))] ); - assert!(h2.input_neighbours(loop_in).next().is_none()); // Outer DFG contains two constants (we know) - a 5, used by the loop, and a 10, output. - let [_, root_out] = h2.get_io(h2.root()).unwrap(); + let [_, root_out] = h.get_io(h.root()).unwrap(); let mut cst5 = Some(cst5.into()); for n in dfg_nodes { - let Some(cst) = h2.get_optype(n).as_const() else { + let Some(cst) = h.get_optype(n).as_const() else { continue; }; - let lcst = h2.output_neighbours(n).exactly_one().unwrap(); - let target = h2.output_neighbours(lcst).exactly_one().unwrap(); + let lcst = h.output_neighbours(n).exactly_one().unwrap(); + let target = h.output_neighbours(lcst).exactly_one().unwrap(); if Some(cst.value()) == cst5.as_ref() { cst5 = None; assert_eq!(target, tl); @@ -1708,11 +1714,24 @@ fn test_tail_loop() { } } assert!(cst5.is_none()); // Found in loop +} - let mut h3 = h.clone(); +#[test] +fn test_tail_loop_never_iterates() { + let (mut h, reg) = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap()); + ConstFoldPass::default() + .with_inputs([(0, Value::true_val())]) // true = 1 = break + .run(&mut h, ®) + .unwrap(); + assert_fully_folded(&h, &ConstInt::new_u(4, 12).unwrap().into()); +} + +#[test] +fn test_tail_loop_increase_termination() { + let (mut h, reg) = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap()); ConstFoldPass::default() .allow_increase_termination() - .run(&mut h3, ®) + .run(&mut h, ®) .unwrap(); - assert_fully_folded(&h3, &ConstInt::new_u(3, 10).unwrap().into()); + assert_fully_folded(&h, &ConstInt::new_u(4, 12).unwrap().into()); } From d90190b0bd73237f67261da610ba8b644772a13a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 5 Nov 2024 09:06:43 +0000 Subject: [PATCH 221/281] refactor: assert_fully_folded takes HugrView --- hugr-passes/src/const_fold/test.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index b7dc6f4db..2a6f533fd 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -62,7 +62,7 @@ fn value_handling(#[case] k: impl CustomConst + Clone, #[case] eq: bool) { } /// Check that a hugr just loads and returns a single expected constant. -pub fn assert_fully_folded(h: &Hugr, expected_value: &Value) { +pub fn assert_fully_folded(h: &impl HugrView, expected_value: &Value) { assert_fully_folded_with(h, |v| v == expected_value) } @@ -71,7 +71,7 @@ pub fn assert_fully_folded(h: &Hugr, expected_value: &Value) { /// /// [CustomConst::equals_const] is not required to be implemented. Use this /// function for Values containing such a `CustomConst`. -fn assert_fully_folded_with(h: &Hugr, check_value: impl Fn(&Value) -> bool) { +fn assert_fully_folded_with(h: &impl HugrView, check_value: impl Fn(&Value) -> bool) { let mut node_count = 0; for node in h.children(h.root()) { From 8f89c88540d3cf0faf700922de9d3d17c365d4cf Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 5 Nov 2024 10:49:29 +0000 Subject: [PATCH 222/281] refactor: find_needed_nodes works out root itself --- hugr-passes/src/const_fold.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index f8e7402af..3e4263a75 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -78,7 +78,7 @@ impl ConstFoldPass { let results = Machine::default().run(ConstFoldContext(hugr), inputs); let mut keep_nodes = HashSet::new(); - self.find_needed_nodes(&results, hugr.root(), &mut keep_nodes); + self.find_needed_nodes(&results, &mut keep_nodes); let remove_nodes = results .hugr() @@ -137,12 +137,11 @@ impl ConstFoldPass { fn find_needed_nodes( &self, results: &AnalysisResults>, - root: Node, needed: &mut HashSet, ) { let mut q = VecDeque::new(); - q.push_back(root); let h = results.hugr(); + q.push_back(h.root()); while let Some(n) = q.pop_front() { if !needed.insert(n) { continue; From a6bb507d8bb8f363b83ff019a6657c99691bd6a0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 5 Nov 2024 08:02:48 +0000 Subject: [PATCH 223/281] CFG test (failing atm) --- hugr-passes/src/const_fold/test.rs | 130 ++++++++++++++++++++++++++++- 1 file changed, 129 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 2a6f533fd..da0499f2f 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -1,12 +1,16 @@ use std::collections::HashSet; -use hugr_core::builder::{endo_sig, inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr}; +use hugr_core::builder::{ + endo_sig, inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + SubContainer, +}; use hugr_core::extension::prelude::{ const_ok, sum_with_error, ConstError, ConstString, MakeTuple, UnpackTuple, BOOL_T, ERROR_TYPE, STRING_TYPE, }; use hugr_core::extension::{ExtensionRegistry, PRELUDE}; use hugr_core::hugr::hugrmut::HugrMut; +use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::{constant::CustomConst, OpType, Value}; use hugr_core::ops::{OpTag, OpTrait}; use hugr_core::std_extensions::arithmetic::{ @@ -1735,3 +1739,127 @@ fn test_tail_loop_increase_termination() { .unwrap(); assert_fully_folded(&h, &ConstInt::new_u(4, 12).unwrap().into()); } + +fn cfg_hugr() -> (Hugr, ExtensionRegistry) { + let reg = ExtensionRegistry::try_new([arithmetic::int_types::EXTENSION.to_owned()]).unwrap(); + let int_ty = INT_TYPES[4].clone(); + let mut builder = DFGBuilder::new(inout_sig(vec![BOOL_T, BOOL_T], int_ty.clone())).unwrap(); + let [p, q] = builder.input_wires_arr(); + let int_cst = builder.add_load_value(ConstInt::new_u(4, 1).unwrap()); + let mut nested = builder + .dfg_builder_endo([(int_ty.clone(), int_cst)]) + .unwrap(); + let [i] = nested.input_wires_arr(); + let mut cfg = nested + .cfg_builder([(int_ty.clone(), i)], int_ty.clone().into()) + .unwrap(); + let mut entry = cfg.simple_entry_builder(int_ty.clone().into(), 2).unwrap(); + let [e_i] = entry.input_wires_arr(); + let e_cst7 = entry.add_load_value(ConstInt::new_u(4, 7).unwrap()); + let e_add = entry + .add_dataflow_op(IntOpDef::iadd.with_log_width(4), [e_cst7, e_i]) + .unwrap(); + let entry = entry.finish_with_outputs(p, e_add.outputs()).unwrap(); + + let mut a = cfg + .simple_block_builder(endo_sig(int_ty.clone()), 2) + .unwrap(); + let [a_i] = a.input_wires_arr(); + let a_cst3 = a.add_load_value(ConstInt::new_u(4, 3).unwrap()); + let a_add = a + .add_dataflow_op(IntOpDef::iadd.with_log_width(4), [a_cst3, a_i]) + .unwrap(); + let a = a.finish_with_outputs(q, a_add.outputs()).unwrap(); + + let x = cfg.exit_block(); + let [tru, fals] = [1, 0]; + cfg.branch(&entry, tru, &a).unwrap(); + cfg.branch(&entry, fals, &x).unwrap(); + cfg.branch(&a, tru, &entry).unwrap(); + cfg.branch(&a, fals, &x).unwrap(); + let cfg = cfg.finish_sub_container().unwrap(); + let nested = nested.finish_with_outputs(cfg.outputs()).unwrap(); + let hugr = builder + .finish_hugr_with_outputs(nested.outputs(), ®) + .unwrap(); + (hugr, reg) +} + +#[rstest] +#[case(&[(0,false)], true, false, Some(8))] +#[case(&[(0,true), (1,false)], true, true, Some(11))] +#[case(&[(1,false)], true, true, None)] +#[case(&[], false, false, None)] +fn test_cfg2( + #[case] inputs: &[(usize, bool)], + #[case] fold_entry: bool, + #[case] fold_blk: bool, + #[case] fold_res: Option, +) { + use hugr_core::ops::handle::BasicBlockID; + + let (backup, reg) = cfg_hugr(); + let mut hugr = backup.clone(); + let pass = ConstFoldPass::default() + .with_inputs(inputs.into_iter().map(|(p, b)| (*p, Value::from_bool(*b)))); + pass.run(&mut hugr, ®).unwrap(); + // CFG inside DFG retained + let nested = hugr + .children(hugr.root()) + .filter(|n| hugr.get_optype(*n).is_dfg()) + .exactly_one() + .unwrap(); + let cfg = hugr + .nodes() + .filter(|n| hugr.get_optype(*n).is_cfg()) + .exactly_one() + .ok() + .unwrap(); + assert_eq!(hugr.get_parent(cfg), Some(nested)); + let [entry, exit, a] = hugr.children(cfg).collect::>().try_into().unwrap(); + assert!(hugr.get_optype(exit).is_exit_block()); + for (blk, is_folded, folded_cst, unfolded_cst) in + [(entry, fold_entry, 8, 7), (a, fold_blk, 11, 3)] + { + if is_folded { + assert_fully_folded( + &DescendantsGraph::::try_new(&hugr, blk).unwrap(), + &ConstInt::new_u(4, folded_cst).unwrap().into(), + ); + } else { + let mut expected_tags = + HashSet::from(["Input", "Output", "Leaf", "Const", "LoadConst"]); + for ch in hugr.children(blk) { + let tag = format!("{:?}", hugr.get_optype(ch).tag()); + assert!(expected_tags.remove(tag.as_str()), "Not found: {}", tag); + if let Some(cst) = hugr.get_optype(ch).as_const() { + assert_eq!( + cst.value(), + &ConstInt::new_u(4, unfolded_cst).unwrap().into() + ); + } else if let Some(op) = hugr.get_optype(ch).as_extension_op() { + assert_eq!(op.def().name(), "iadd"); + } + } + } + } + let output_src = hugr + .input_neighbours(hugr.get_io(hugr.root()).unwrap()[1]) + .exactly_one() + .unwrap(); + if let Some(res_int) = fold_res { + let res_v = ConstInt::new_u(4, res_int as _).unwrap().into(); + assert!(hugr.get_optype(output_src).is_load_constant()); + let output_cst = hugr.input_neighbours(output_src).exactly_one().unwrap(); + let cst = hugr.get_optype(output_cst).as_const().unwrap(); + assert_eq!(cst.value(), &res_v); + + let mut hugr2 = backup; + pass.allow_increase_termination() + .run(&mut hugr2, ®) + .unwrap(); + assert_fully_folded(&hugr2, &res_v); + } else { + assert_eq!(output_src, nested); + } +} From 3e05639ac9146f78f31e8332bd20836132320c3c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 5 Nov 2024 10:54:25 +0000 Subject: [PATCH 224/281] fix/change-policy: don't break edges from root input --- hugr-passes/src/const_fold.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 3e4263a75..631eee365 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -79,6 +79,7 @@ impl ConstFoldPass { let results = Machine::default().run(ConstFoldContext(hugr), inputs); let mut keep_nodes = HashSet::new(); self.find_needed_nodes(&results, &mut keep_nodes); + let [root_inp, _] = hugr.get_io(hugr.root()).unwrap(); let remove_nodes = results .hugr() @@ -98,7 +99,9 @@ impl ConstFoldPass { // This would insert fewer constants, but potentially expose less parallelism. .filter_map(|(n, ip)| { let (src, outp) = hugr.single_linked_output(n, ip).unwrap(); - (!hugr.get_optype(src).is_load_constant()).then_some(( + // Avoid breaking edges from existing LoadConstant (we'd only add another) + // or froom root input node (i.e. hardcoding provided "external inputs" into graph) + (!hugr.get_optype(src).is_load_constant() && src != root_inp).then_some(( n, ip, results From 5b2a2c9bab766a836be02f50a9e541855b592151 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 5 Nov 2024 10:58:09 +0000 Subject: [PATCH 225/281] Fix: don't follow CFG edges, but don't panic either --- hugr-passes/src/const_fold.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 631eee365..a063fbe26 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -179,7 +179,7 @@ impl ConstFoldPass { .is_err() } EdgeKind::StateOrder | EdgeKind::Const(_) | EdgeKind::Function(_) => true, - EdgeKind::ControlFlow => panic!(), + EdgeKind::ControlFlow => false, // we always include all children of a CFG above _ => true, // needed as EdgeKind non-exhaustive; not knowing what it is, assume the worst }; if needs_predecessor { From dcf76a6130fb79670a41f2f68779d5ee2d807026 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 5 Nov 2024 11:03:57 +0000 Subject: [PATCH 226/281] clippy, moving might_diverge outside ConstFoldPass+parametrize --- hugr-passes/src/const_fold.rs | 55 +++++++++++++++++------------------ 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index a063fbe26..1858d7043 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -22,8 +22,8 @@ use value_handle::ValueHandle; use crate::{ dataflow::{ - partial_from_const, AnalysisResults, ConstLoader, DFContext, Machine, PartialValue, - TailLoopTermination, + partial_from_const, AbstractValue, AnalysisResults, ConstLoader, DFContext, Machine, + PartialValue, TailLoopTermination, }, validation::{ValidatePassError, ValidationLevel}, }; @@ -71,7 +71,7 @@ impl ConstFoldPass { )); let inputs = self.inputs.iter().map(|(p, v)| { ( - p.clone(), + *p, partial_from_const(&ConstFoldContext(hugr), fresh_node, &mut vec![p.index()], v), ) }); @@ -163,7 +163,7 @@ impl ConstFoldPass { // Also add on anything that might not terminate (even if results not required - // if its results are required we'll add it by following dataflow, below.) for ch in h.children(n) { - if self.might_diverge(results, ch) { + if might_diverge(results, ch) { q.push_back(ch); } } @@ -188,31 +188,30 @@ impl ConstFoldPass { } } } +} - // "Diverge" aka "never-terminate" - // TODO would be more efficient to compute this bottom-up and cache (dynamic programming) - fn might_diverge( - &self, - results: &AnalysisResults>, - n: Node, - ) -> bool { - let op = results.hugr().get_optype(n); - if op.is_cfg() { - // TODO if the CFG has no cycles (that are possible given predicates) - // then we could say it definitely terminates (i.e. return false) - true - } else if op.is_tail_loop() - && results.tail_loop_terminates(n).unwrap() != TailLoopTermination::NeverContinues - { - // If we can even figure out the number of iterations is bounded that would allow returning false. - true - } else { - // Node does not introduce non-termination, but still non-terminates if any of its children does - results - .hugr() - .children(n) - .any(|ch| self.might_diverge(results, ch)) - } +// "Diverge" aka "never-terminate" +// TODO would be more efficient to compute this bottom-up and cache (dynamic programming) +fn might_diverge( + results: &AnalysisResults>, + n: Node, +) -> bool { + let op = results.hugr().get_optype(n); + if op.is_cfg() { + // TODO if the CFG has no cycles (that are possible given predicates) + // then we could say it definitely terminates (i.e. return false) + true + } else if op.is_tail_loop() + && results.tail_loop_terminates(n).unwrap() != TailLoopTermination::NeverContinues + { + // If we can even figure out the number of iterations is bounded that would allow returning false. + true + } else { + // Node does not introduce non-termination, but still non-terminates if any of its children does + results + .hugr() + .children(n) + .any(|ch| might_diverge(results, ch)) } } From 92488fe63db85dd1b5046565156adc00e4890e17 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 5 Nov 2024 11:06:51 +0000 Subject: [PATCH 227/281] Improve const_fold module doc --- hugr-passes/src/const_fold.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 1858d7043..61c72e6d2 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -1,6 +1,6 @@ #![warn(missing_docs)] -//! An (example) use of the [super::dataflow](dataflow-analysis framework) -//! to perform constant-folding. +//! Constant-folding pass. +//! An (example) use of the [dataflow analysis framework](super::dataflow). pub mod value_handle; use std::collections::{HashMap, HashSet, VecDeque}; From 69d0f5e59a5a787a6e222daaca622ca03114332a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 8 Nov 2024 10:00:15 +0000 Subject: [PATCH 228/281] try_into_value: reorder type params, separate out where clause --- hugr-passes/src/dataflow/partial_value.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 0086629a1..992a72444 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -150,10 +150,13 @@ impl PartialSum { /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] /// supporting the single possible tag with the correct number of elements and no row variables; /// or if converting a child element failed via [PartialValue::try_into_value]. - pub fn try_into_value + TryFrom, Error = SE>>( + pub fn try_into_value( self, typ: &Type, - ) -> Result, ExtractValueError> { + ) -> Result, ExtractValueError> + where + V2: TryFrom + TryFrom, Error = SE>, + { let Ok((k, v)) = self.0.iter().exactly_one() else { return Err(ExtractValueError::MultipleVariants(self)); }; @@ -334,10 +337,10 @@ impl PartialValue { /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) /// that could not be converted into a [Sum] by [PartialSum::try_into_value] (e.g. if `typ` is /// incorrect), or if that [Sum] could not be converted into a `V2`. - pub fn try_into_value + TryFrom, Error = SE>>( - self, - typ: &Type, - ) -> Result> { + pub fn try_into_value(self, typ: &Type) -> Result> + where + V2: TryFrom + TryFrom, Error = SE>, + { match self { Self::Value(v) => V2::try_from(v.clone()) .map_err(|e| ExtractValueError::CouldNotConvert(v.clone(), e)), From 2624ee85f7607726b1b82d00bb918f70942717af Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 8 Nov 2024 10:02:25 +0000 Subject: [PATCH 229/281] We don't actually use portgraph, nor downcast-rs --- hugr-passes/Cargo.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index caa7e0af0..311fe781c 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -14,9 +14,7 @@ categories = ["compilers"] [dependencies] hugr-core = { path = "../hugr-core", version = "0.13.3" } -portgraph = { workspace = true } ascent = { version = "0.7.0" } -downcast-rs = { workspace = true } itertools = { workspace = true } lazy_static = { workspace = true } paste = { workspace = true } From ec526e83d4c1cea0c2d45f2686672d6578b44d66 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 8 Nov 2024 10:23:13 +0000 Subject: [PATCH 230/281] Import RandomState from std::collections::hash_map for rust 1.75 --- hugr-passes/src/dataflow/datalog.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 81f415fd0..dbca253da 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,7 +1,7 @@ //! [ascent] datalog implementation of analysis. -use std::collections::HashSet; -use std::hash::RandomState; +use std::collections::hash_map::RandomState; +use std::collections::HashSet; // Moves to std::hash in Rust 1.76 use ascent::lattice::BoundedLattice; use itertools::Itertools; From 5650ee40b2f5aed54657a221274ecc97b01d771b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 8 Nov 2024 13:48:36 +0000 Subject: [PATCH 231/281] Use BREAK_TAG/CONTINUE_TAG --- hugr-passes/src/dataflow/datalog.rs | 12 +++++++----- hugr-passes/src/dataflow/test.rs | 26 +++++++++++++++++++++----- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index dbca253da..33ab4524d 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -7,7 +7,7 @@ use ascent::lattice::BoundedLattice; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::{OpTrait, OpType}; +use hugr_core::ops::{OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; @@ -156,16 +156,18 @@ pub(super) fn run_datalog>( if let Some(tailloop) = ctx.get_optype(*tl).as_tail_loop(), input_child(tl, in_n), output_child(tl, out_n), - node_in_value_row(out_n, out_in_row), // get the whole input row for the output node - if let Some(fields) = out_in_row.unpack_first(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 + node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... + // ...and select just what's possible for CONTINUE_TAG, if anything + if let Some(fields) = out_in_row.unpack_first(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()), for (out_p, v) in fields.enumerate(); // Output node of child region propagate to outputs of tail loop out_wire_value(tl, OutgoingPort::from(out_p), v) <-- node(tl), if let Some(tailloop) = ctx.get_optype(*tl).as_tail_loop(), output_child(tl, out_n), - node_in_value_row(out_n, out_in_row), // get the whole input row for the output node - if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 + node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... + // ... and select just what's possible for BREAK_TAG, if anything + if let Some(fields) = out_in_row.unpack_first(TailLoop::BREAK_TAG, tailloop.just_outputs.len()), for (out_p, v) in fields.enumerate(); // Conditional -------------------- diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index a300965b1..44446760e 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -3,6 +3,7 @@ use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::builder::{CFGBuilder, Container, DataflowHugr}; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::handle::DfgID; +use hugr_core::ops::TailLoop; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -94,7 +95,10 @@ fn test_tail_loop_never_iterates() { let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); let r_v = Value::unit_sum(3, 6).unwrap(); let r_w = builder.add_load_value(r_v.clone()); - let tag = Tag::new(1, vec![type_row![], r_v.get_type().into()]); + let tag = Tag::new( + TailLoop::BREAK_TAG, + vec![type_row![], r_v.get_type().into()], + ); let tagged = builder.add_dataflow_op(tag, [r_w]).unwrap(); let tlb = builder @@ -117,8 +121,14 @@ fn test_tail_loop_never_iterates() { #[test] fn test_tail_loop_always_iterates() { let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); - let r_w = builder - .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); + let r_w = builder.add_load_value( + Value::sum( + TailLoop::CONTINUE_TAG, + [], + SumType::new([type_row![], BOOL_T.into()]), + ) + .unwrap(), + ); let true_w = builder.add_load_value(Value::true_val()); let tlb = builder @@ -221,13 +231,19 @@ fn test_tail_loop_containing_conditional() { .unwrap() .outputs_arr(); let cont = case0_b - .add_dataflow_op(Tag::new(0, body_out_variants.clone()), [next_input]) + .add_dataflow_op( + Tag::new(TailLoop::CONTINUE_TAG, body_out_variants.clone()), + [next_input], + ) .unwrap(); case0_b.finish_with_outputs(cont.outputs()).unwrap(); // Second iter 1(true, false) => exit with (true, false) let mut case1_b = cond.case_builder(1).unwrap(); let loop_res = case1_b - .add_dataflow_op(Tag::new(1, body_out_variants), case1_b.input_wires()) + .add_dataflow_op( + Tag::new(TailLoop::BREAK_TAG, body_out_variants), + case1_b.input_wires(), + ) .unwrap(); case1_b.finish_with_outputs(loop_res.outputs()).unwrap(); let [r] = cond.finish_sub_container().unwrap().outputs_arr(); From da2981cf3d8e840d9e833b602a317ae5e0213623 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 8 Nov 2024 13:52:10 +0000 Subject: [PATCH 232/281] No, use make_break/make_continue for easy cases --- hugr-passes/src/dataflow/test.rs | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 44446760e..c0fbf395a 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -213,6 +213,7 @@ fn test_tail_loop_containing_conditional() { let mut tlb = builder .tail_loop_builder([(control_t, init)], [], type_row![BOOL_T; 2]) .unwrap(); + let tl = tlb.loop_signature().unwrap().clone(); let [in_w] = tlb.input_wires_arr(); // Branch on in_wire, so first iter 0(false, true)... @@ -230,22 +231,12 @@ fn test_tail_loop_containing_conditional() { .add_dataflow_op(Tag::new(1, control_variants), [b, a]) .unwrap() .outputs_arr(); - let cont = case0_b - .add_dataflow_op( - Tag::new(TailLoop::CONTINUE_TAG, body_out_variants.clone()), - [next_input], - ) - .unwrap(); - case0_b.finish_with_outputs(cont.outputs()).unwrap(); + let cont = case0_b.make_continue(tl.clone(), [next_input]).unwrap(); + case0_b.finish_with_outputs([cont]).unwrap(); // Second iter 1(true, false) => exit with (true, false) let mut case1_b = cond.case_builder(1).unwrap(); - let loop_res = case1_b - .add_dataflow_op( - Tag::new(TailLoop::BREAK_TAG, body_out_variants), - case1_b.input_wires(), - ) - .unwrap(); - case1_b.finish_with_outputs(loop_res.outputs()).unwrap(); + let loop_res = case1_b.make_break(tl, case1_b.input_wires()).unwrap(); + case1_b.finish_with_outputs([loop_res]).unwrap(); let [r] = cond.finish_sub_container().unwrap().outputs_arr(); let tail_loop = tlb.finish_with_outputs(r, []).unwrap(); From 0b684549442525d868ebba0f48a98f7927d60c2a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 14 Nov 2024 13:40:56 +0000 Subject: [PATCH 233/281] Refactor bb_reachable using then --- hugr-passes/src/dataflow/results.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index 6c90e33b3..21d6b13c0 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -69,14 +69,11 @@ impl> AnalysisResults { let cfg = self.hugr().get_parent(bb)?; // Not really required...?? self.hugr().get_optype(cfg).as_cfg()?; let t = self.hugr().get_optype(bb); - if !t.is_dataflow_block() && !t.is_exit_block() { - return None; - }; - Some( + (t.is_dataflow_block() || t.is_exit_block()).then(|| { self.bb_reachable .iter() - .any(|(cfg2, bb2)| *cfg2 == cfg && *bb2 == bb), - ) + .any(|(cfg2, bb2)| *cfg2 == cfg && *bb2 == bb) + }) } /// Reads a concrete representation of the value on an output wire, if the lattice value From 4916e9d6b7a099a6a1db871f1cddf3c677b3b51e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 14 Nov 2024 14:24:35 +0000 Subject: [PATCH 234/281] ConstLocation with Box - a lot of cloning --- hugr-passes/src/dataflow.rs | 58 ++++++++++++++++++++++++++++--------- 1 file changed, 45 insertions(+), 13 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index dd3e6d2c0..a1e84bad1 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -41,6 +41,42 @@ pub trait DFContext: ConstLoader + std::ops::Deref { } } +/// A location where a [Value] could be find in a Hugr. That is, +/// (perhaps deeply nested within [Value::Sum]s) within a [Node] +/// that is a [Const](hugr_core::ops::Const). +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum ConstLocation>> { + /// The specified-index'th field of the [Value::Sum] constant identified by the RHS + Field(usize, C), + /// The entire ([Const::value](hugr_core::ops::Const::value)) of the node. + Node(Node), +} + +struct SharedConstLocation<'a>(ConstLocation<&'a SharedConstLocation<'a>>); + +impl<'a> AsRef>> for SharedConstLocation<'a> { + fn as_ref(&self) -> &ConstLocation<&'a SharedConstLocation<'a>> { + &self.0 + } +} + +struct BoxedConstLocation(Box>); + +impl AsRef> for BoxedConstLocation { + fn as_ref(&self) -> &ConstLocation { + &self.0 + } +} + +impl<'a> From<&SharedConstLocation<'a>> for BoxedConstLocation { + fn from(value: &SharedConstLocation<'a>) -> Self { + BoxedConstLocation(Box::new(match value.0 { + ConstLocation::Node(n) => ConstLocation::Node(n), + ConstLocation::Field(idx, elem) => ConstLocation::Field(idx, elem.into()), + })) + } +} + /// Trait for loading [PartialValue]s from constant [Value]s in a Hugr. /// Implementors will likely want to override some/all of [Self::value_from_opaque], /// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults @@ -52,18 +88,18 @@ pub trait ConstLoader { /// converts these using [Self::value_from_opaque] and [Self::value_from_const_hugr], /// and builds nested [PartialValue::new_variant] to represent the structure. fn value_from_const(&self, n: Node, cst: &Value) -> PartialValue { - traverse_value(self, n, &mut Vec::new(), cst) + traverse_value(self, SharedConstLocation(ConstLocation::Node(n)), cst) } /// Produces an abstract value from an [OpaqueValue], if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. - fn value_from_opaque(&self, _node: Node, _fields: &[usize], _val: &OpaqueValue) -> Option { + fn value_from_opaque(&self, _loc: SharedConstLocation, _val: &OpaqueValue) -> Option { None } /// Produces an abstract value from a Hugr in a [Value::Function], if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. - fn value_from_const_hugr(&self, _node: Node, _fields: &[usize], _h: &Hugr) -> Option { + fn value_from_const_hugr(&self, _loc: SharedConstLocation, _h: &Hugr) -> Option { None } @@ -81,26 +117,22 @@ pub trait ConstLoader { fn traverse_value( s: &(impl ConstLoader + ?Sized), - n: Node, - fields: &mut Vec, + loc: SharedConstLocation, cst: &Value, ) -> PartialValue { match cst { Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { - let elems = values.iter().enumerate().map(|(idx, elem)| { - fields.push(idx); - let r = traverse_value(s, n, fields, elem); - fields.pop(); - r - }); + let elems = values.iter().enumerate().map(|(idx, elem)| + traverse_value(s, SharedConstLocation(ConstLocation::Field(idx, &loc)), elem) + ); PartialValue::new_variant(*tag, elems) } Value::Extension { e } => s - .value_from_opaque(n, fields, e) + .value_from_opaque(loc, e) .map(PartialValue::from) .unwrap_or(PartialValue::Top), Value::Function { hugr } => s - .value_from_const_hugr(n, fields, hugr) + .value_from_const_hugr(loc, hugr) .map(PartialValue::from) .unwrap_or(PartialValue::Top), } From a4a64c011a7e669256d869d9f9bc9531d6e8f8d0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 14 Nov 2024 15:18:35 +0000 Subject: [PATCH 235/281] No - Revert - just make ConstLocation store a reference --- hugr-passes/src/dataflow.rs | 39 +++++++------------------------------ 1 file changed, 7 insertions(+), 32 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index a1e84bad1..802f00022 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -45,38 +45,13 @@ pub trait DFContext: ConstLoader + std::ops::Deref { /// (perhaps deeply nested within [Value::Sum]s) within a [Node] /// that is a [Const](hugr_core::ops::Const). #[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum ConstLocation>> { +pub enum ConstLocation<'a> { /// The specified-index'th field of the [Value::Sum] constant identified by the RHS - Field(usize, C), + Field(usize, &'a ConstLocation<'a>), /// The entire ([Const::value](hugr_core::ops::Const::value)) of the node. Node(Node), } -struct SharedConstLocation<'a>(ConstLocation<&'a SharedConstLocation<'a>>); - -impl<'a> AsRef>> for SharedConstLocation<'a> { - fn as_ref(&self) -> &ConstLocation<&'a SharedConstLocation<'a>> { - &self.0 - } -} - -struct BoxedConstLocation(Box>); - -impl AsRef> for BoxedConstLocation { - fn as_ref(&self) -> &ConstLocation { - &self.0 - } -} - -impl<'a> From<&SharedConstLocation<'a>> for BoxedConstLocation { - fn from(value: &SharedConstLocation<'a>) -> Self { - BoxedConstLocation(Box::new(match value.0 { - ConstLocation::Node(n) => ConstLocation::Node(n), - ConstLocation::Field(idx, elem) => ConstLocation::Field(idx, elem.into()), - })) - } -} - /// Trait for loading [PartialValue]s from constant [Value]s in a Hugr. /// Implementors will likely want to override some/all of [Self::value_from_opaque], /// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults @@ -88,18 +63,18 @@ pub trait ConstLoader { /// converts these using [Self::value_from_opaque] and [Self::value_from_const_hugr], /// and builds nested [PartialValue::new_variant] to represent the structure. fn value_from_const(&self, n: Node, cst: &Value) -> PartialValue { - traverse_value(self, SharedConstLocation(ConstLocation::Node(n)), cst) + traverse_value(self, ConstLocation::Node(n), cst) } /// Produces an abstract value from an [OpaqueValue], if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. - fn value_from_opaque(&self, _loc: SharedConstLocation, _val: &OpaqueValue) -> Option { + fn value_from_opaque(&self, _loc: ConstLocation, _val: &OpaqueValue) -> Option { None } /// Produces an abstract value from a Hugr in a [Value::Function], if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. - fn value_from_const_hugr(&self, _loc: SharedConstLocation, _h: &Hugr) -> Option { + fn value_from_const_hugr(&self, _loc: ConstLocation, _h: &Hugr) -> Option { None } @@ -117,13 +92,13 @@ pub trait ConstLoader { fn traverse_value( s: &(impl ConstLoader + ?Sized), - loc: SharedConstLocation, + loc: ConstLocation, cst: &Value, ) -> PartialValue { match cst { Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { let elems = values.iter().enumerate().map(|(idx, elem)| - traverse_value(s, SharedConstLocation(ConstLocation::Field(idx, &loc)), elem) + traverse_value(s, ConstLocation::Field(idx, &loc), elem) ); PartialValue::new_variant(*tag, elems) } From bc39b76e9b3c6840db59071ecad352456beca48d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 14 Nov 2024 15:19:48 +0000 Subject: [PATCH 236/281] {value=>partial}_from_const, takes ConstLoc, inline traverse_value --- hugr-passes/src/dataflow.rs | 42 ++++++++++++----------------- hugr-passes/src/dataflow/datalog.rs | 4 +-- 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 802f00022..dca97fb77 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -62,8 +62,23 @@ pub trait ConstLoader { /// to their leaves ([Value::Extension] and [Value::Function]), /// converts these using [Self::value_from_opaque] and [Self::value_from_const_hugr], /// and builds nested [PartialValue::new_variant] to represent the structure. - fn value_from_const(&self, n: Node, cst: &Value) -> PartialValue { - traverse_value(self, ConstLocation::Node(n), cst) + fn partial_from_const(&self, loc: ConstLocation, cst: &Value) -> PartialValue { + match cst { + Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { + let elems = values.iter().enumerate().map(|(idx, elem)| { + self.partial_from_const(ConstLocation::Field(idx, &loc), elem) + }); + PartialValue::new_variant(*tag, elems) + } + Value::Extension { e } => self + .value_from_opaque(loc, e) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + Value::Function { hugr } => self + .value_from_const_hugr(loc, hugr) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + } } /// Produces an abstract value from an [OpaqueValue], if possible. @@ -90,28 +105,5 @@ pub trait ConstLoader { } } -fn traverse_value( - s: &(impl ConstLoader + ?Sized), - loc: ConstLocation, - cst: &Value, -) -> PartialValue { - match cst { - Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { - let elems = values.iter().enumerate().map(|(idx, elem)| - traverse_value(s, ConstLocation::Field(idx, &loc), elem) - ); - PartialValue::new_variant(*tag, elems) - } - Value::Extension { e } => s - .value_from_opaque(loc, e) - .map(PartialValue::from) - .unwrap_or(PartialValue::Top), - Value::Function { hugr } => s - .value_from_const_hugr(loc, hugr) - .map(PartialValue::from) - .unwrap_or(PartialValue::Top), - } -} - #[cfg(test)] mod test; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 33ab4524d..3d0acf20b 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -11,7 +11,7 @@ use hugr_core::ops::{OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; -use super::{AbstractValue, AnalysisResults, DFContext, PartialValue}; +use super::{AbstractValue, AnalysisResults, ConstLocation, DFContext, PartialValue}; type PV = PartialValue; @@ -308,7 +308,7 @@ fn propagate_leaf_op( Some(ValueRow::single_known( 1, 0, - ctx.value_from_const(n, const_val), + ctx.partial_from_const(ConstLocation::Node(n), const_val), )) } OpType::LoadFunction(load_op) => { From 92669db2c3576bea591d4daa4718625d1f35af9d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 14 Nov 2024 15:25:48 +0000 Subject: [PATCH 237/281] Make ConstLocation Copy --- hugr-passes/src/dataflow.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index dca97fb77..5ebcc1afb 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -44,7 +44,7 @@ pub trait DFContext: ConstLoader + std::ops::Deref { /// A location where a [Value] could be find in a Hugr. That is, /// (perhaps deeply nested within [Value::Sum]s) within a [Node] /// that is a [Const](hugr_core::ops::Const). -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum ConstLocation<'a> { /// The specified-index'th field of the [Value::Sum] constant identified by the RHS Field(usize, &'a ConstLocation<'a>), From 33c860721c2b914ac74382f88e664759755958b4 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 14 Nov 2024 15:38:22 +0000 Subject: [PATCH 238/281] ConstLocation is From; move partial_from_const out to toplev, no value_from_const --- hugr-passes/src/dataflow.rs | 59 +++++++++++++++++------------ hugr-passes/src/dataflow/datalog.rs | 4 +- 2 files changed, 37 insertions(+), 26 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 5ebcc1afb..3769eaced 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -52,35 +52,17 @@ pub enum ConstLocation<'a> { Node(Node), } +impl<'a> From for ConstLocation<'a> { + fn from(value: Node) -> Self { + ConstLocation::Node(value) + } +} + /// Trait for loading [PartialValue]s from constant [Value]s in a Hugr. /// Implementors will likely want to override some/all of [Self::value_from_opaque], /// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults /// are "correct" but maximally conservative (minimally informative). pub trait ConstLoader { - /// Produces a [PartialValue] from a constant. The default impl (expected - /// to be appropriate in most cases) traverses [Sum](Value::Sum) constants - /// to their leaves ([Value::Extension] and [Value::Function]), - /// converts these using [Self::value_from_opaque] and [Self::value_from_const_hugr], - /// and builds nested [PartialValue::new_variant] to represent the structure. - fn partial_from_const(&self, loc: ConstLocation, cst: &Value) -> PartialValue { - match cst { - Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { - let elems = values.iter().enumerate().map(|(idx, elem)| { - self.partial_from_const(ConstLocation::Field(idx, &loc), elem) - }); - PartialValue::new_variant(*tag, elems) - } - Value::Extension { e } => self - .value_from_opaque(loc, e) - .map(PartialValue::from) - .unwrap_or(PartialValue::Top), - Value::Function { hugr } => self - .value_from_const_hugr(loc, hugr) - .map(PartialValue::from) - .unwrap_or(PartialValue::Top), - } - } - /// Produces an abstract value from an [OpaqueValue], if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. fn value_from_opaque(&self, _loc: ConstLocation, _val: &OpaqueValue) -> Option { @@ -105,5 +87,34 @@ pub trait ConstLoader { } } +/// Produces a [PartialValue] from a constant. Traverses [Sum](Value::Sum) constants +/// to their leaves ([Value::Extension] and [Value::Function]), +/// converts these using [ConstLoader::value_from_opaque] and [ConstLoader::value_from_const_hugr], +/// and builds nested [PartialValue::new_variant] to represent the structure. +fn partial_from_const<'a, V>( + cl: &impl ConstLoader, + loc: impl Into>, + cst: &Value, +) -> PartialValue { + let loc = loc.into(); + match cst { + Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { + let elems = values + .iter() + .enumerate() + .map(|(idx, elem)| partial_from_const(cl, ConstLocation::Field(idx, &loc), elem)); + PartialValue::new_variant(*tag, elems) + } + Value::Extension { e } => cl + .value_from_opaque(loc, e) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + Value::Function { hugr } => cl + .value_from_const_hugr(loc, hugr) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + } +} + #[cfg(test)] mod test; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 3d0acf20b..303b96acf 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -11,7 +11,7 @@ use hugr_core::ops::{OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; -use super::{AbstractValue, AnalysisResults, ConstLocation, DFContext, PartialValue}; +use super::{partial_from_const, AbstractValue, AnalysisResults, DFContext, PartialValue}; type PV = PartialValue; @@ -308,7 +308,7 @@ fn propagate_leaf_op( Some(ValueRow::single_known( 1, 0, - ctx.partial_from_const(ConstLocation::Node(n), const_val), + partial_from_const(ctx, n, const_val), )) } OpType::LoadFunction(load_op) => { From 8b76135795810faabaad8429a78b26aa32f84df0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 18 Nov 2024 21:53:02 +0000 Subject: [PATCH 239/281] Generalize run to deal with Module(use main), and others; add run_lib --- hugr-passes/src/dataflow/datalog.rs | 93 ++++++++++++++++++++++------- 1 file changed, 73 insertions(+), 20 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 303b96acf..4a1448cc5 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -7,7 +7,7 @@ use ascent::lattice::BoundedLattice; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::{OpTrait, OpType, TailLoop}; +use hugr_core::ops::{DataflowParent, NamedOp, OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; @@ -29,46 +29,99 @@ impl Default for Machine { } impl Machine { - /// Provide initial values for some wires. - // Likely for test purposes only - should we make non-pub or #[cfg(test)] ? - pub fn prepopulate_wire(&mut self, h: &impl HugrView, wire: Wire, value: PartialValue) { + // Provide initial values for a wire - these will be `join`d with any computed. + // pub(crate) so can be used for tests. + pub(crate) fn prepopulate_wire(&mut self, h: &impl HugrView, w: Wire, v: PartialValue) { self.0.extend( - h.linked_inputs(wire.node(), wire.source()) - .map(|(n, inp)| (n, inp, value.clone())), + h.linked_inputs(w.node(), w.source()) + .map(|(n, inp)| (n, inp, v.clone())), ); } /// Run the analysis (iterate until a lattice fixpoint is reached), - /// given initial values for some of the root node inputs. - /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, - /// but should handle other containers.) + /// given initial values for some of the root node inputs. For a + /// [Module](OpType::Module)-rooted Hugr, these are input to the function `"main"`. /// The context passed in allows interpretation of leaf operations. + /// + /// [Module]: OpType::Module pub fn run>( mut self, context: C, in_values: impl IntoIterator)>, ) -> AnalysisResults { let root = context.root(); - self.0 - .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); + // Some nodes do not accept values as dataflow inputs - for these + // we must find the corresponding Output node. + let out_node_parent = match context.get_optype(root) { + OpType::Module(_) => Some( + context + .children(root) + .find(|n| { + context + .get_optype(*n) + .as_func_defn() + .is_some_and(|f| f.name() == "main") + }) + .expect("Module must contain a 'main' function to be analysed"), + ), + OpType::DataflowBlock(_) | OpType::Case(_) | OpType::FuncDefn(_) => Some(root), + // Could also do Dfg above, but ok here too: + _ => None, // Just feed into node inputs + }; + // Now write values onto Input node out-wires or Outputs. // Any inputs we don't have values for, we must assume `Top` to ensure safety of analysis // (Consider: for a conditional that selects *either* the unknown input *or* value V, // analysis must produce Top == we-know-nothing, not `V` !) - let mut need_inputs: HashSet<_, RandomState> = HashSet::from_iter( - (0..context.signature(root).unwrap_or_default().input_count()).map(IncomingPort::from), - ); - self.0.iter().for_each(|(n, p, _)| { - if n == &root { - need_inputs.remove(p); + if let Some(p) = out_node_parent { + let [inp, _] = context.get_io(p).unwrap(); + let mut vals = + vec![PartialValue::Top; context.signature(inp).unwrap().output_types().len()]; + for (ip, v) in in_values { + vals[ip.index()] = v; + } + for (i, v) in vals.into_iter().enumerate() { + self.prepopulate_wire(&*context, Wire::new(inp, i), v); + } + } else { + self.0 + .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); + let mut need_inputs: HashSet<_, RandomState> = HashSet::from_iter( + (0..context.signature(root).unwrap_or_default().input_count()) + .map(IncomingPort::from), + ); + self.0.iter().for_each(|(n, p, _)| { + if n == &root { + need_inputs.remove(p); + } + }); + for p in need_inputs { + self.0.push((root, p, PartialValue::Top)); } - }); - for p in need_inputs { - self.0.push((root, p, PartialValue::Top)); } // Note/TODO, if analysis is running on a subregion then we should do similar // for any nonlocal edges providing values from outside the region. run_datalog(context, self.0) } + + /// Run the analysis (iterate until a lattice fixpoint is reached), + /// for a [Module]-rooted Hugr where all functions are assumed callable + /// (from a client) with any arguments. + /// The context passed in allows interpretation of leaf operations. + pub fn run_lib>(mut self, context: C) -> AnalysisResults { + let root = context.root(); + if !context.get_optype(root).is_module() { + panic!("Hugr not Module-rooted") + } + for n in context.children(root) { + if let Some(fd) = context.get_optype(n).as_func_defn() { + let [inp, _] = context.get_io(n).unwrap(); + for p in 0..fd.inner_signature().input_count() { + self.prepopulate_wire(&*context, Wire::new(inp, p), PartialValue::Top); + } + } + } + run_datalog(context, self.0) + } } pub(super) fn run_datalog>( From c18cbea26b14d248a5a9a130d7fe82aff63b0c6f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 18 Nov 2024 22:03:15 +0000 Subject: [PATCH 240/281] Shorten the got-all-required-inputs check (build got_inputs) --- hugr-passes/src/dataflow/datalog.rs | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 4a1448cc5..9efbdc041 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -85,17 +85,15 @@ impl Machine { } else { self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); - let mut need_inputs: HashSet<_, RandomState> = HashSet::from_iter( - (0..context.signature(root).unwrap_or_default().input_count()) - .map(IncomingPort::from), - ); - self.0.iter().for_each(|(n, p, _)| { - if n == &root { - need_inputs.remove(p); + let got_inputs: HashSet<_, RandomState> = self + .0 + .iter() + .filter_map(|(n, p, _)| (n == &root).then_some(*p)) + .collect(); + for p in context.signature(root).unwrap_or_default().input_ports() { + if !got_inputs.contains(&p) { + self.0.push((root, p, PartialValue::Top)); } - }); - for p in need_inputs { - self.0.push((root, p, PartialValue::Top)); } } // Note/TODO, if analysis is running on a subregion then we should do similar From 1b64b4bbd598825d1313b175cb768094e99f830e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 18 Nov 2024 22:19:16 +0000 Subject: [PATCH 241/281] Shorten further...not as easy to follow --- hugr-passes/src/dataflow/datalog.rs | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 9efbdc041..0b0b57ce2 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,8 +1,5 @@ //! [ascent] datalog implementation of analysis. -use std::collections::hash_map::RandomState; -use std::collections::HashSet; // Moves to std::hash in Rust 1.76 - use ascent::lattice::BoundedLattice; use itertools::Itertools; @@ -85,15 +82,13 @@ impl Machine { } else { self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); - let got_inputs: HashSet<_, RandomState> = self - .0 - .iter() - .filter_map(|(n, p, _)| (n == &root).then_some(*p)) - .collect(); - for p in context.signature(root).unwrap_or_default().input_ports() { - if !got_inputs.contains(&p) { - self.0.push((root, p, PartialValue::Top)); - } + let mut need_inputs = + vec![true; context.signature(root).unwrap_or_default().input_count()]; + for (_, p, _) in self.0.iter().filter(|(n, _, _)| n == &root) { + need_inputs[p.index()] = false; + } + for (i, _) in need_inputs.into_iter().enumerate().filter(|(_, b)| *b) { + self.0.push((root, i.into(), PartialValue::Top)); } } // Note/TODO, if analysis is running on a subregion then we should do similar From 39b8df16e6c05297ad403b593095fcf3144afb72 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 09:14:26 +0000 Subject: [PATCH 242/281] Revert "Shorten further...not as easy to follow" This reverts commit 1b64b4bbd598825d1313b175cb768094e99f830e. --- hugr-passes/src/dataflow/datalog.rs | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 0b0b57ce2..9efbdc041 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,5 +1,8 @@ //! [ascent] datalog implementation of analysis. +use std::collections::hash_map::RandomState; +use std::collections::HashSet; // Moves to std::hash in Rust 1.76 + use ascent::lattice::BoundedLattice; use itertools::Itertools; @@ -82,13 +85,15 @@ impl Machine { } else { self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); - let mut need_inputs = - vec![true; context.signature(root).unwrap_or_default().input_count()]; - for (_, p, _) in self.0.iter().filter(|(n, _, _)| n == &root) { - need_inputs[p.index()] = false; - } - for (i, _) in need_inputs.into_iter().enumerate().filter(|(_, b)| *b) { - self.0.push((root, i.into(), PartialValue::Top)); + let got_inputs: HashSet<_, RandomState> = self + .0 + .iter() + .filter_map(|(n, p, _)| (n == &root).then_some(*p)) + .collect(); + for p in context.signature(root).unwrap_or_default().input_ports() { + if !got_inputs.contains(&p) { + self.0.push((root, p, PartialValue::Top)); + } } } // Note/TODO, if analysis is running on a subregion then we should do similar From a5d987c9b42f8f452ff69a68ec4627b10076e182 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 11:50:18 +0000 Subject: [PATCH 243/281] doc fixes, rename to run_library --- hugr-passes/src/dataflow/datalog.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 9efbdc041..74133a5bd 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -15,10 +15,11 @@ use super::{partial_from_const, AbstractValue, AnalysisResults, DFContext, Parti type PV = PartialValue; +#[allow(rustdoc::private_intra_doc_links)] /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] /// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values -/// 3. Call [Self::run] to produce [AnalysisResults] +/// 3. Call [Self::run] or [Self::run_library] to produce [AnalysisResults] pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); /// derived-Default requires the context to be Defaultable, which is unnecessary @@ -102,10 +103,10 @@ impl Machine { } /// Run the analysis (iterate until a lattice fixpoint is reached), - /// for a [Module]-rooted Hugr where all functions are assumed callable + /// for a [Module](OpType::Module)-rooted Hugr where all functions are assumed callable /// (from a client) with any arguments. /// The context passed in allows interpretation of leaf operations. - pub fn run_lib>(mut self, context: C) -> AnalysisResults { + pub fn run_library>(mut self, context: C) -> AnalysisResults { let root = context.root(); if !context.get_optype(root).is_module() { panic!("Hugr not Module-rooted") From 3e718fdcc95ee1865d8a6e7ed22ed731ac2beace Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 11:52:46 +0000 Subject: [PATCH 244/281] Add PartialValue::contains_bottom, also row_contains_bottom --- hugr-passes/src/dataflow.rs | 8 ++++++++ hugr-passes/src/dataflow/partial_value.rs | 21 +++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 3769eaced..f6e710f66 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -116,5 +116,13 @@ fn partial_from_const<'a, V>( } } +/// A row of inputs to a node contains bottom (can't happen, the node +/// can't execute) if any element [contains_bottom](PartialValue::contains_bottom). +pub fn row_contains_bottom<'a, V: AbstractValue + 'a>( + elements: impl IntoIterator>, +) -> bool { + elements.into_iter().any(PartialValue::contains_bottom) +} + #[cfg(test)] mod test; diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 992a72444..cd0b1fb29 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -8,6 +8,8 @@ use std::collections::HashMap; use std::hash::{Hash, Hasher}; use thiserror::Error; +use super::row_contains_bottom; + /// Trait for an underlying domain of abstract values which can form the *elements* of a /// [PartialValue] and thus be used in dataflow analysis. pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { @@ -181,6 +183,13 @@ impl PartialSum { num_elements: v.len(), }) } + + /// Can this ever occur at runtime? See [PartialValue::contains_bottom] + pub fn contains_bottom(&self) -> bool { + self.0 + .iter() + .all(|(_tag, elements)| row_contains_bottom(elements)) + } } /// An error converting a [PartialValue] or [PartialSum] into a concrete value type @@ -352,6 +361,18 @@ impl PartialValue { Self::Bottom => Err(ExtractValueError::ValueIsBottom), } } + + /// A value contains bottom means that it cannot occur during execution + /// - it may be an artefact during bootstrapping of the analysis, or else + /// the value depends upon a `panic` or a loop that + /// [never terminates](super::TailLoopTermination::NeverBreaks). + pub fn contains_bottom(&self) -> bool { + match self { + PartialValue::Bottom => true, + PartialValue::Top | PartialValue::Value(_) => false, + PartialValue::PartialSum(ps) => ps.contains_bottom(), + } + } } impl TryFrom> for Value { From 497686ae927a5ded7b4401fd36cc8a7081392e57 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 12:02:48 +0000 Subject: [PATCH 245/281] Don't call interpret_leaf_op if row_contains_bottom --- hugr-passes/src/dataflow/datalog.rs | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 74133a5bd..e33494ac4 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -11,7 +11,10 @@ use hugr_core::ops::{DataflowParent, NamedOp, OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; -use super::{partial_from_const, AbstractValue, AnalysisResults, DFContext, PartialValue}; +use super::{ + partial_from_const, row_contains_bottom, AbstractValue, AnalysisResults, DFContext, + PartialValue, +}; type PV = PartialValue; @@ -378,20 +381,19 @@ fn propagate_leaf_op( )) } OpType::ExtensionOp(e) => { - // Interpret op using DFContext - let init = if ins.iter().contains(&PartialValue::Bottom) { + Some(ValueRow::from_iter(if row_contains_bottom(ins) { // So far we think one or more inputs can't happen. // So, don't pollute outputs with Top, and wait for better knowledge of inputs. - PartialValue::Bottom + vec![PartialValue::Bottom; num_outs] } else { - // If we can't figure out anything about the outputs, assume nothing (they still happen!) - PartialValue::Top - }; - let mut outs = vec![init; num_outs]; - // It might be nice to convert these to [(IncomingPort, Value)], or some concrete value, - // for the context, but PV contains more information, and try_into_value may fail. - ctx.interpret_leaf_op(n, e, ins, &mut outs[..]); - Some(ValueRow::from_iter(outs)) + // Interpret op using DFContext + // Default to Top i.e. can't figure out anything about the outputs + let mut outs = vec![PartialValue::Top; num_outs]; + // It might be nice to convert `ins`` to [(IncomingPort, Value)], or some concrete value, + // for the context, but PV contains more information, and try_into_value may fail. + ctx.interpret_leaf_op(n, e, ins, &mut outs[..]); + outs + })) } o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" } From e34c7bef059295be4bc439ec5db35cd4b92eeded Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 12:25:44 +0000 Subject: [PATCH 246/281] Use row_contains_bottom for CFG+DFG, and augment unpack_first(=>_no_bottom) --- hugr-passes/src/dataflow/datalog.rs | 17 +++++++++++------ hugr-passes/src/dataflow/value_row.rs | 15 +++++++++------ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index e33494ac4..9a0426ae6 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -194,7 +194,10 @@ pub(super) fn run_datalog>( dfg_node(n) <-- node(n), if ctx.get_optype(*n).is_dfg(); out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), - input_child(dfg, i), in_wire_value(dfg, p, v); + input_child(dfg, i), + node_in_value_row(dfg, row), + if !row_contains_bottom(&row[..]), // Treat the DFG as a scheduling barrier + for (p, v) in row[..].iter().enumerate(); out_wire_value(dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), output_child(dfg, o), in_wire_value(o, p, v); @@ -213,7 +216,7 @@ pub(super) fn run_datalog>( output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... // ...and select just what's possible for CONTINUE_TAG, if anything - if let Some(fields) = out_in_row.unpack_first(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()), + if let Some(fields) = out_in_row.unpack_first_no_bottom(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()), for (out_p, v) in fields.enumerate(); // Output node of child region propagate to outputs of tail loop @@ -222,7 +225,7 @@ pub(super) fn run_datalog>( output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... // ... and select just what's possible for BREAK_TAG, if anything - if let Some(fields) = out_in_row.unpack_first(TailLoop::BREAK_TAG, tailloop.just_outputs.len()), + if let Some(fields) = out_in_row.unpack_first_no_bottom(TailLoop::BREAK_TAG, tailloop.just_outputs.len()), for (out_p, v) in fields.enumerate(); // Conditional -------------------- @@ -239,7 +242,7 @@ pub(super) fn run_datalog>( input_child(case, i_node), node_in_value_row(cond, in_row), let conditional = ctx.get_optype(*cond).as_conditional().unwrap(), - if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), + if let Some(fields) = in_row.unpack_first_no_bottom(*case_index, conditional.sum_rows[*case_index].len()), for (out_p, v) in fields.enumerate(); // outputs of case nodes propagate to outputs of conditional *if* case reachable @@ -274,7 +277,9 @@ pub(super) fn run_datalog>( cfg_node(cfg), if let Some(entry) = ctx.children(*cfg).next(), input_child(entry, i_node), - in_wire_value(cfg, p, v); + node_in_value_row(cfg, row), + if !row_contains_bottom(&row[..]), + for (p, v) in row[..].iter().enumerate(); // In `CFG` , values fed along a control-flow edge to // come out of Value outports of : @@ -293,7 +298,7 @@ pub(super) fn run_datalog>( output_child(pred, out_n), _cfg_succ_dest(cfg, succ, dest), node_in_value_row(out_n, out_in_row), - if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), + if let Some(fields) = out_in_row.unpack_first_no_bottom(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), for (out_p, v) in fields.enumerate(); // Call -------------------- diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs index 0d8bc15a6..fc1d66818 100644 --- a/hugr-passes/src/dataflow/value_row.rs +++ b/hugr-passes/src/dataflow/value_row.rs @@ -8,7 +8,7 @@ use std::{ use ascent::{lattice::BoundedLattice, Lattice}; use itertools::zip_eq; -use super::{AbstractValue, PartialValue}; +use super::{row_contains_bottom, AbstractValue, PartialValue}; #[derive(PartialEq, Clone, Debug, Eq, Hash)] pub(super) struct ValueRow(Vec>); @@ -25,16 +25,19 @@ impl ValueRow { r } - /// The first value in this ValueRow must be a sum; - /// returns a new ValueRow given by unpacking the elements of the specified variant of said first value, - /// then appending the rest of the values in this row. - pub fn unpack_first( + /// If the first value in this ValueRow is a sum, that might contain + /// the specified tag, then unpack the elements of that tag, append the rest + /// of this ValueRow, and if none of the elements of that row [contain bottom](PartialValue::contains_bottom), + /// return it. + /// Otherwise (if no such tag, or values contain bottom), return None. + pub fn unpack_first_no_bottom( &self, variant: usize, len: usize, ) -> Option>> { let vals = self[0].variant_values(variant, len)?; - Some(vals.into_iter().chain(self.0[1..].to_owned())) + (!row_contains_bottom(vals.iter().chain(self.0[1..].iter()))) + .then(|| vals.into_iter().chain(self.0[1..].to_owned())) } } From 69a69f3be88ebbfc947a5bbd4e13107fdb5da59c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 13:30:36 +0000 Subject: [PATCH 247/281] run_library => publish_function --- hugr-passes/src/dataflow/datalog.rs | 75 ++++++++++++++--------------- 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 9a0426ae6..7a1d866b6 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -7,7 +7,7 @@ use ascent::lattice::BoundedLattice; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::{DataflowParent, NamedOp, OpTrait, OpType, TailLoop}; +use hugr_core::ops::{DataflowParent, FuncDefn, NamedOp, OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; @@ -18,11 +18,12 @@ use super::{ type PV = PartialValue; -#[allow(rustdoc::private_intra_doc_links)] /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] -/// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values -/// 3. Call [Self::run] or [Self::run_library] to produce [AnalysisResults] +/// 2. (Optionally) For [Module](OpType::Module)-rooted Hugrs, zero or more calls +/// to [Self::publish_function] +// or [Self::prepopulate_wire] with initial values +/// 3. Call [Self::run] to produce [AnalysisResults] pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); /// derived-Default requires the context to be Defaultable, which is unnecessary @@ -44,39 +45,35 @@ impl Machine { /// Run the analysis (iterate until a lattice fixpoint is reached), /// given initial values for some of the root node inputs. For a - /// [Module](OpType::Module)-rooted Hugr, these are input to the function `"main"`. + /// [Module](OpType::Module)-rooted Hugr, these are input to the function `"main"` + /// (it is an error if inputs are provided and there is no `"main"``). /// The context passed in allows interpretation of leaf operations. - /// - /// [Module]: OpType::Module pub fn run>( mut self, context: C, in_values: impl IntoIterator)>, ) -> AnalysisResults { + let mut in_values = in_values.into_iter(); let root = context.root(); // Some nodes do not accept values as dataflow inputs - for these // we must find the corresponding Output node. - let out_node_parent = match context.get_optype(root) { - OpType::Module(_) => Some( - context - .children(root) - .find(|n| { - context - .get_optype(*n) - .as_func_defn() - .is_some_and(|f| f.name() == "main") - }) - .expect("Module must contain a 'main' function to be analysed"), - ), + let input_node_parent = match context.get_optype(root) { + OpType::Module(_) => { + let main = find_func(&*context, "main"); + if main.is_none() && in_values.next().is_some() { + panic!("Cannot give inputs to module with no 'main'"); + } + main.map(|(n, _)| n) + } OpType::DataflowBlock(_) | OpType::Case(_) | OpType::FuncDefn(_) => Some(root), // Could also do Dfg above, but ok here too: _ => None, // Just feed into node inputs }; - // Now write values onto Input node out-wires or Outputs. // Any inputs we don't have values for, we must assume `Top` to ensure safety of analysis // (Consider: for a conditional that selects *either* the unknown input *or* value V, // analysis must produce Top == we-know-nothing, not `V` !) - if let Some(p) = out_node_parent { + if let Some(p) = input_node_parent { + // Put values onto out-wires of Input node let [inp, _] = context.get_io(p).unwrap(); let mut vals = vec![PartialValue::Top; context.signature(inp).unwrap().output_types().len()]; @@ -87,6 +84,7 @@ impl Machine { self.prepopulate_wire(&*context, Wire::new(inp, i), v); } } else { + // Put values onto in-wires of root node self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); let got_inputs: HashSet<_, RandomState> = self @@ -105,27 +103,28 @@ impl Machine { run_datalog(context, self.0) } - /// Run the analysis (iterate until a lattice fixpoint is reached), - /// for a [Module](OpType::Module)-rooted Hugr where all functions are assumed callable - /// (from a client) with any arguments. - /// The context passed in allows interpretation of leaf operations. - pub fn run_library>(mut self, context: C) -> AnalysisResults { - let root = context.root(); - if !context.get_optype(root).is_module() { - panic!("Hugr not Module-rooted") + /// For [Module](OpType::Module)-rooted Hugrs, mark a FuncDefn that is a child + /// of the root node as externally callable, i.e. with any arguments. + pub fn publish_function>(&mut self, context: C, name: &str) { + let (n, fd) = find_func(&*context, name).unwrap(); + let [inp, _] = context.get_io(n).unwrap(); + for p in 0..fd.inner_signature().input_count() { + self.prepopulate_wire(&*context, Wire::new(inp, p), PartialValue::Top); } - for n in context.children(root) { - if let Some(fd) = context.get_optype(n).as_func_defn() { - let [inp, _] = context.get_io(n).unwrap(); - for p in 0..fd.inner_signature().input_count() { - self.prepopulate_wire(&*context, Wire::new(inp, p), PartialValue::Top); - } - } - } - run_datalog(context, self.0) } } +fn find_func<'a>(h: &'a impl HugrView, name: &str) -> Option<(Node, &'a FuncDefn)> { + assert!(h.get_optype(h.root()).is_module()); + h.children(h.root()) + .filter_map(|n| { + h.get_optype(n) + .as_func_defn() + .and_then(|f| (f.name() == name).then_some((n, f))) + }) + .next() +} + pub(super) fn run_datalog>( ctx: C, in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, From 57ac432c1e35a1f6c57844fc21ae5422c7adc564 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 14:58:47 +0000 Subject: [PATCH 248/281] Drop publish_function, pub prepopulate_wire --- hugr-passes/src/dataflow/datalog.rs | 55 +++++++++++------------------ 1 file changed, 21 insertions(+), 34 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 7a1d866b6..b29185827 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -7,7 +7,7 @@ use ascent::lattice::BoundedLattice; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::{DataflowParent, FuncDefn, NamedOp, OpTrait, OpType, TailLoop}; +use hugr_core::ops::{NamedOp, OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; @@ -20,9 +20,10 @@ type PV = PartialValue; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] -/// 2. (Optionally) For [Module](OpType::Module)-rooted Hugrs, zero or more calls -/// to [Self::publish_function] -// or [Self::prepopulate_wire] with initial values +/// 2. (Optionally) zero or more calls to [Self::prepopulate_wire] with initial values. +/// For example, for a [Module](OpType::Module)-rooted Hugr, each externally-callable +/// [FuncDefn](OpType::FuncDefn) should have the out-wires from its [Input](OpType::Input) +/// node prepopulated with [PartialValue::Top]. /// 3. Call [Self::run] to produce [AnalysisResults] pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); @@ -34,9 +35,8 @@ impl Default for Machine { } impl Machine { - // Provide initial values for a wire - these will be `join`d with any computed. - // pub(crate) so can be used for tests. - pub(crate) fn prepopulate_wire(&mut self, h: &impl HugrView, w: Wire, v: PartialValue) { + /// Provide initial values for a wire - these will be `join`d with any computed. + pub fn prepopulate_wire(&mut self, h: &impl HugrView, w: Wire, v: PartialValue) { self.0.extend( h.linked_inputs(w.node(), w.source()) .map(|(n, inp)| (n, inp, v.clone())), @@ -45,9 +45,12 @@ impl Machine { /// Run the analysis (iterate until a lattice fixpoint is reached), /// given initial values for some of the root node inputs. For a - /// [Module](OpType::Module)-rooted Hugr, these are input to the function `"main"` - /// (it is an error if inputs are provided and there is no `"main"``). + /// [Module](OpType::Module)-rooted Hugr, these are input to the function `"main"`. /// The context passed in allows interpretation of leaf operations. + /// + /// # Panics + /// May panic in various ways if the Hugr is invalid; + /// or if any `in_values` are provided for a module-rooted Hugr without a function `"main"`. pub fn run>( mut self, context: C, @@ -56,14 +59,19 @@ impl Machine { let mut in_values = in_values.into_iter(); let root = context.root(); // Some nodes do not accept values as dataflow inputs - for these - // we must find the corresponding Output node. + // we must find the corresponding Input node. let input_node_parent = match context.get_optype(root) { OpType::Module(_) => { - let main = find_func(&*context, "main"); + let main = context.children(root).find(|n| { + context + .get_optype(*n) + .as_func_defn() + .is_some_and(|f| f.name() == "main") + }); if main.is_none() && in_values.next().is_some() { panic!("Cannot give inputs to module with no 'main'"); } - main.map(|(n, _)| n) + main } OpType::DataflowBlock(_) | OpType::Case(_) | OpType::FuncDefn(_) => Some(root), // Could also do Dfg above, but ok here too: @@ -84,7 +92,7 @@ impl Machine { self.prepopulate_wire(&*context, Wire::new(inp, i), v); } } else { - // Put values onto in-wires of root node + // Put values onto in-wires of root node, datalog will do the rest self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); let got_inputs: HashSet<_, RandomState> = self @@ -102,27 +110,6 @@ impl Machine { // for any nonlocal edges providing values from outside the region. run_datalog(context, self.0) } - - /// For [Module](OpType::Module)-rooted Hugrs, mark a FuncDefn that is a child - /// of the root node as externally callable, i.e. with any arguments. - pub fn publish_function>(&mut self, context: C, name: &str) { - let (n, fd) = find_func(&*context, name).unwrap(); - let [inp, _] = context.get_io(n).unwrap(); - for p in 0..fd.inner_signature().input_count() { - self.prepopulate_wire(&*context, Wire::new(inp, p), PartialValue::Top); - } - } -} - -fn find_func<'a>(h: &'a impl HugrView, name: &str) -> Option<(Node, &'a FuncDefn)> { - assert!(h.get_optype(h.root()).is_module()); - h.children(h.root()) - .filter_map(|n| { - h.get_optype(n) - .as_func_defn() - .and_then(|f| (f.name() == name).then_some((n, f))) - }) - .next() } pub(super) fn run_datalog>( From f9a9f2446bd7fac321c02a707e452d5a06781edb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 15:10:24 +0000 Subject: [PATCH 249/281] ValueRow::single_known => singleton, set --- hugr-passes/src/dataflow/datalog.rs | 12 +++--------- hugr-passes/src/dataflow/value_row.rs | 12 +++++++----- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index b29185827..69e26f029 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -163,7 +163,7 @@ pub(super) fn run_datalog>( // Assemble node_in_value_row from in_wire_value's node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = ctx.signature(*n); - node_in_value_row(n, ValueRow::single_known(ctx.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); + node_in_value_row(n, ValueRow::new(ctx.signature(*n).unwrap().input_count()).set(p.index(), v.clone())) <-- in_wire_value(n, p, v); // Interpret leaf ops out_wire_value(n, p, v) <-- @@ -351,11 +351,7 @@ fn propagate_leaf_op( .unwrap() .0; let const_val = ctx.get_optype(const_node).as_const().unwrap().value(); - Some(ValueRow::single_known( - 1, - 0, - partial_from_const(ctx, n, const_val), - )) + Some(ValueRow::singleton(partial_from_const(ctx, n, const_val))) } OpType::LoadFunction(load_op) => { assert!(ins.is_empty()); // static edge @@ -364,9 +360,7 @@ fn propagate_leaf_op( .unwrap() .0; // Node could be a FuncDefn or a FuncDecl, so do not pass the node itself - Some(ValueRow::single_known( - 1, - 0, + Some(ValueRow::singleton( ctx.value_from_function(func_node, &load_op.type_args) .map_or(PV::Top, PV::Value), )) diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs index fc1d66818..9360f36e3 100644 --- a/hugr-passes/src/dataflow/value_row.rs +++ b/hugr-passes/src/dataflow/value_row.rs @@ -18,11 +18,13 @@ impl ValueRow { Self(vec![PartialValue::bottom(); len]) } - pub fn single_known(len: usize, idx: usize, v: PartialValue) -> Self { - assert!(idx < len); - let mut r = Self::new(len); - r.0[idx] = v; - r + pub fn set(mut self, idx: usize, v: PartialValue) -> Self { + *self.0.get_mut(idx).unwrap() = v; + self + } + + pub fn singleton(v: PartialValue) -> Self { + Self(vec![v]) } /// If the first value in this ValueRow is a sum, that might contain From 9cc368df844a1c4e7c65aee76fc9eb3906ea8044 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 15:23:30 +0000 Subject: [PATCH 250/281] try_join / try_meet return extra bool --- hugr-passes/src/dataflow/partial_value.rs | 29 ++++++++++++----------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index cd0b1fb29..8e394ec43 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -14,23 +14,26 @@ use super::row_contains_bottom; /// [PartialValue] and thus be used in dataflow analysis. pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { /// Computes the join of two values (i.e. towards `Top``), if this is representable - /// within the underlying domain. - /// Otherwise return `None` (i.e. an instruction to use [PartialValue::Top]). + /// within the underlying domain. Return the new value, and whether this is different from + /// the old `self`. /// - /// The default checks equality between `self` and `other` and returns `self` if + /// If the join is not representable, return `None` - i.e., we should use [PartialValue::Top]. + /// + /// The default checks equality between `self` and `other` and returns `(self,false)` if /// the two are identical, otherwise `None`. - fn try_join(self, other: Self) -> Option { - (self == other).then_some(self) + fn try_join(self, other: Self) -> Option<(Self, bool)> { + (self == other).then_some((self, false)) } /// Computes the meet of two values (i.e. towards `Bottom`), if this is representable - /// within the underlying domain. - /// Otherwise return `None` (i.e. an instruction to use [PartialValue::Bottom]). + /// within the underlying domain. Return the new value, and whether this is different from + /// the old `self`. + /// If the meet is not representable, return `None` - i.e., we should use [PartialValue::Bottom]. /// - /// The default checks equality between `self` and `other` and returns `self` if + /// The default checks equality between `self` and `other` and returns `(self, false)` if /// the two are identical, otherwise `None`. - fn try_meet(self, other: Self) -> Option { - (self == other).then_some(self) + fn try_meet(self, other: Self) -> Option<(Self, bool)> { + (self == other).then_some((self, false)) } } @@ -398,8 +401,7 @@ impl Lattice for PartialValue { true } (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_join(h2) { - Some(h3) => { - let ch = h3 != *h1; + Some((h3, ch)) => { *self = Self::Value(h3); ch } @@ -441,8 +443,7 @@ impl Lattice for PartialValue { true } (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_meet(h2) { - Some(h3) => { - let ch = h3 != *h1; + Some((h3, ch)) => { *self = Self::Value(h3); ch } From a61fbdb73f8895856fd31724c90018e27c794e5f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 15:31:47 +0000 Subject: [PATCH 251/281] shorten/common-up meet_mut + join_mut --- hugr-passes/src/dataflow/partial_value.rs | 36 ++++++++++------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 8e394ec43..e6142a9f2 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -400,16 +400,14 @@ impl Lattice for PartialValue { *self = other; true } - (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_join(h2) { - Some((h3, ch)) => { - *self = Self::Value(h3); - ch - } - None => { - *self = Self::Top; - true - } - }, + (Self::Value(h1), Self::Value(h2)) => { + let (nv, ch) = match h1.clone().try_join(h2) { + Some((h3, b)) => (Self::Value(h3), b), + None => (Self::Top, true), + }; + *self = nv; + ch + } (Self::PartialSum(_), Self::PartialSum(ps2)) => { let Self::PartialSum(ps1) = self else { unreachable!() @@ -442,16 +440,14 @@ impl Lattice for PartialValue { *self = other; true } - (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_meet(h2) { - Some((h3, ch)) => { - *self = Self::Value(h3); - ch - } - None => { - *self = Self::Bottom; - true - } - }, + (Self::Value(h1), Self::Value(h2)) => { + let (h3, ch) = match h1.clone().try_meet(h2) { + Some((h3, ch)) => (Self::Value(h3), ch), + None => (Self::Bottom, true), + }; + *self = h3; + ch + } (Self::PartialSum(_), Self::PartialSum(ps2)) => { let Self::PartialSum(ps1) = self else { unreachable!() From 24cce0e01a656f8d1d49a086393bf30612070bb7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 16:44:21 +0000 Subject: [PATCH 252/281] try_into_value: change bounds TryFrom -> TryInto; rename =>try_into_sum --- hugr-passes/src/dataflow/partial_value.rs | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index e6142a9f2..21c36668d 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -155,12 +155,13 @@ impl PartialSum { /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] /// supporting the single possible tag with the correct number of elements and no row variables; /// or if converting a child element failed via [PartialValue::try_into_value]. - pub fn try_into_value( + pub fn try_into_sum( self, typ: &Type, ) -> Result, ExtractValueError> where - V2: TryFrom + TryFrom, Error = SE>, + V: TryInto, + Sum: TryInto, { let Ok((k, v)) = self.0.iter().exactly_one() else { return Err(ExtractValueError::MultipleVariants(self)); @@ -351,15 +352,18 @@ impl PartialValue { /// incorrect), or if that [Sum] could not be converted into a `V2`. pub fn try_into_value(self, typ: &Type) -> Result> where - V2: TryFrom + TryFrom, Error = SE>, + V: TryInto, + Sum: TryInto, { match self { - Self::Value(v) => V2::try_from(v.clone()) + Self::Value(v) => v + .clone() + .try_into() .map_err(|e| ExtractValueError::CouldNotConvert(v.clone(), e)), - Self::PartialSum(ps) => { - let v = ps.try_into_value(typ)?; - V2::try_from(v).map_err(ExtractValueError::CouldNotBuildSum) - } + Self::PartialSum(ps) => ps + .try_into_sum(typ)? + .try_into() + .map_err(ExtractValueError::CouldNotBuildSum), Self::Top => Err(ExtractValueError::ValueIsTop), Self::Bottom => Err(ExtractValueError::ValueIsBottom), } From a59076629e08ed71127d010debd27191e661042d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 16:56:12 +0000 Subject: [PATCH 253/281] Avoid a clone in try_into_sum --- hugr-passes/src/dataflow/partial_value.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 21c36668d..989d40640 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -163,17 +163,18 @@ impl PartialSum { V: TryInto, Sum: TryInto, { - let Ok((k, v)) = self.0.iter().exactly_one() else { + if self.0.len() != 1 { return Err(ExtractValueError::MultipleVariants(self)); - }; + } + let (tag, v) = self.0.into_iter().exactly_one().unwrap(); if let TypeEnum::Sum(st) = typ.as_type_enum() { - if let Some(r) = st.get_variant(*k) { + if let Some(r) = st.get_variant(tag) { if let Ok(r) = TypeRow::try_from(r.clone()) { if v.len() == r.len() { return Ok(Sum { - tag: *k, + tag, values: zip_eq(v, r.iter()) - .map(|(v, t)| v.clone().try_into_value(t)) + .map(|(v, t)| v.try_into_value(t)) .collect::, _>>()?, st: st.clone(), }); @@ -183,7 +184,7 @@ impl PartialSum { } Err(ExtractValueError::BadSumType { typ: typ.clone(), - tag: *k, + tag, num_elements: v.len(), }) } From 2b2c461397cdea43e1f8b14395aef5a8e648b512 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 15:45:47 +0000 Subject: [PATCH 254/281] Optimize+shorten join_mut / meet_mut via std::mem::swap --- hugr-passes/src/dataflow/partial_value.rs | 102 ++++++++++------------ 1 file changed, 44 insertions(+), 58 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 989d40640..a9933e586 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -394,81 +394,67 @@ impl TryFrom> for Value { impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); - match (&*self, other) { - (Self::Top, _) => false, - (_, other @ Self::Top) => { - *self = other; - true + let mut old_self = Self::Top; // Good default result + std::mem::swap(self, &mut old_self); + match (old_self, other) { + (Self::Top, _) => false, // result is Top + (_, Self::Top) => true, // result is Top + (old, Self::Bottom) => { + *self = old; // reinstate + false } - (_, Self::Bottom) => false, (Self::Bottom, other) => { *self = other; true } - (Self::Value(h1), Self::Value(h2)) => { - let (nv, ch) = match h1.clone().try_join(h2) { - Some((h3, b)) => (Self::Value(h3), b), - None => (Self::Top, true), - }; - *self = nv; - ch - } - (Self::PartialSum(_), Self::PartialSum(ps2)) => { - let Self::PartialSum(ps1) = self else { - unreachable!() - }; - match ps1.try_join_mut(ps2) { - Ok(ch) => ch, - Err(_) => { - *self = Self::Top; - true - } + (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_join(h2) { + Some((h3, b)) => { + *self = Self::Value(h3); + b } - } - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { - *self = Self::Top; - true - } + None => true, // result is Top + }, + (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_join_mut(ps2) { + Ok(ch) => { + *self = Self::PartialSum(ps1); + ch + } + Err(_) => true, // result is Top + }, + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => true, // result is Top } } fn meet_mut(&mut self, other: Self) -> bool { self.assert_invariants(); - match (&*self, other) { - (Self::Bottom, _) => false, - (_, other @ Self::Bottom) => { - *self = other; - true + let mut old_self = Self::Bottom; // Good default result + std::mem::swap(self, &mut old_self); + match (old_self, other) { + (Self::Bottom, _) => false, // result is Bottom + (_, Self::Bottom) => true, // result is Bottom + (old, Self::Top) => { + *self = old; //reinstate + false } - (_, Self::Top) => false, (Self::Top, other) => { *self = other; true } - (Self::Value(h1), Self::Value(h2)) => { - let (h3, ch) = match h1.clone().try_meet(h2) { - Some((h3, ch)) => (Self::Value(h3), ch), - None => (Self::Bottom, true), - }; - *self = h3; - ch - } - (Self::PartialSum(_), Self::PartialSum(ps2)) => { - let Self::PartialSum(ps1) = self else { - unreachable!() - }; - match ps1.try_meet_mut(ps2) { - Ok(ch) => ch, - Err(_) => { - *self = Self::Bottom; - true - } + (Self::Value(h1), Self::Value(h2)) => match h1.try_meet(h2) { + Some((h3, ch)) => { + *self = Self::Value(h3); + ch } - } - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { - *self = Self::Bottom; - true - } + None => true, //result is Bottom + }, + (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_meet_mut(ps2) { + Ok(ch) => { + *self = Self::PartialSum(ps1); + ch + } + Err(_) => true, + }, + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => true, // result is Bottom } } } From 124718d08758b86ac0106a41cceb04e3917468a4 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 17:00:50 +0000 Subject: [PATCH 255/281] refactor join_mut / meet_mut again, common-up assignment --- hugr-passes/src/dataflow/partial_value.rs | 76 +++++++++-------------- 1 file changed, 28 insertions(+), 48 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index a9933e586..4bf5e927f 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -394,68 +394,48 @@ impl TryFrom> for Value { impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); - let mut old_self = Self::Top; // Good default result + let mut old_self = Self::Top; std::mem::swap(self, &mut old_self); - match (old_self, other) { - (Self::Top, _) => false, // result is Top - (_, Self::Top) => true, // result is Top - (old, Self::Bottom) => { - *self = old; // reinstate - false - } - (Self::Bottom, other) => { - *self = other; - true - } + let (res, ch) = match (old_self, other) { + (old @ Self::Top, _) | (old, Self::Bottom) => (old, false), + (_, other @ Self::Top) | (Self::Bottom, other) => (other, true), (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_join(h2) { - Some((h3, b)) => { - *self = Self::Value(h3); - b - } - None => true, // result is Top + Some((h3, b)) => (Self::Value(h3), b), + None => (Self::Top, true), }, (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_join_mut(ps2) { - Ok(ch) => { - *self = Self::PartialSum(ps1); - ch - } - Err(_) => true, // result is Top + Ok(ch) => (Self::PartialSum(ps1), ch), + Err(_) => (Self::Top, true), }, - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => true, // result is Top - } + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { + (Self::Top, true) + } + }; + *self = res; + ch } fn meet_mut(&mut self, other: Self) -> bool { self.assert_invariants(); - let mut old_self = Self::Bottom; // Good default result + let mut old_self = Self::Bottom; std::mem::swap(self, &mut old_self); - match (old_self, other) { - (Self::Bottom, _) => false, // result is Bottom - (_, Self::Bottom) => true, // result is Bottom - (old, Self::Top) => { - *self = old; //reinstate - false - } - (Self::Top, other) => { - *self = other; - true - } + let (res, ch) = match (old_self, other) { + (old @ Self::Bottom, _) | (old, Self::Top) => (old, false), + (_, other @ Self::Bottom) | (Self::Top, other) => (other, true), (Self::Value(h1), Self::Value(h2)) => match h1.try_meet(h2) { - Some((h3, ch)) => { - *self = Self::Value(h3); - ch - } - None => true, //result is Bottom + Some((h3, ch)) => (Self::Value(h3), ch), + None => (Self::Bottom, true), }, (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_meet_mut(ps2) { - Ok(ch) => { - *self = Self::PartialSum(ps1); - ch - } - Err(_) => true, + Ok(ch) => (Self::PartialSum(ps1), ch), + Err(_) => (Self::Bottom, true), }, - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => true, // result is Bottom - } + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { + (Self::Bottom, true) + } + }; + *self = res; + ch } } From 93b1f4d64501f3182b40415939571e1662cbfbde Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 17:19:21 +0000 Subject: [PATCH 256/281] clippy --- hugr-passes/src/dataflow/datalog.rs | 3 +-- hugr-passes/src/dataflow/partial_value.rs | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 69e26f029..ab8c7c6f3 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -93,8 +93,7 @@ impl Machine { } } else { // Put values onto in-wires of root node, datalog will do the rest - self.0 - .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); + self.0.extend(in_values.map(|(p, v)| (root, p, v))); let got_inputs: HashSet<_, RandomState> = self .0 .iter() diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 4bf5e927f..2a3507144 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -370,8 +370,8 @@ impl PartialValue { } } - /// A value contains bottom means that it cannot occur during execution - /// - it may be an artefact during bootstrapping of the analysis, or else + /// A value contains bottom means that it cannot occur during execution: + /// it may be an artefact during bootstrapping of the analysis, or else /// the value depends upon a `panic` or a loop that /// [never terminates](super::TailLoopTermination::NeverBreaks). pub fn contains_bottom(&self) -> bool { From 731a3b05bd7967af7bd5f06782c9e13b0d326588 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 17:28:39 +0000 Subject: [PATCH 257/281] doclinks --- hugr-passes/src/dataflow/partial_value.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 2a3507144..60a3ae514 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -344,12 +344,12 @@ impl PartialValue { /// Turns this instance into some "concrete" value type `V2`, *if* it is a single value, /// or a [Sum](PartialValue::PartialSum) (of a single tag) convertible by - /// [PartialSum::try_into_value]. + /// [PartialSum::try_into_sum]. /// /// # Errors /// /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) - /// that could not be converted into a [Sum] by [PartialSum::try_into_value] (e.g. if `typ` is + /// that could not be converted into a [Sum] by [PartialSum::try_into_sum] (e.g. if `typ` is /// incorrect), or if that [Sum] could not be converted into a `V2`. pub fn try_into_value(self, typ: &Type) -> Result> where From 5b51434b3af08d3692508a2340112c4653881394 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 17:32:22 +0000 Subject: [PATCH 258/281] missing docs --- hugr-passes/src/const_fold/value_handle.rs | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/const_fold/value_handle.rs b/hugr-passes/src/const_fold/value_handle.rs index e984c41fe..e1e69ae63 100644 --- a/hugr-passes/src/const_fold/value_handle.rs +++ b/hugr-passes/src/const_fold/value_handle.rs @@ -49,9 +49,12 @@ pub enum ValueHandle { Hashable(HashedConst), /// Either a [Value::Extension] that can't be hashed, or a [Value::Function]. Unhashable { + /// The node (i.e. a [Const](hugr_core::ops::Const)) containing the constant node: Node, + /// Indices within [Value::Sum]s containing the unhashable [Self::Unhashable::leaf] fields: Vec, - value: Either, Arc>, + /// The unhashable [Value::Extension] or [Value::Function] + leaf: Either, Arc>, }, } @@ -76,7 +79,7 @@ impl ValueHandle { Self::Unhashable { node, fields, - value: Either::Left(arc), + leaf: Either::Left(arc), }, Self::Hashable, ) @@ -88,7 +91,7 @@ impl ValueHandle { Self::Unhashable { node, fields, - value: Either::Right(Arc::from(val)), + leaf: Either::Right(Arc::from(val)), } } } @@ -103,12 +106,12 @@ impl PartialEq for ValueHandle { Self::Unhashable { node: n1, fields: f1, - value: _, + leaf: _, }, Self::Unhashable { node: n2, fields: f2, - value: _, + leaf: _, }, ) => { // If the keys are equal, we return true since the values must have the @@ -132,7 +135,7 @@ impl Hash for ValueHandle { ValueHandle::Unhashable { node, fields, - value: _, + leaf: _, } => { node.hash(state); fields.hash(state); @@ -148,13 +151,13 @@ impl From for Value { match value { ValueHandle::Hashable(HashedConst { val, .. }) | ValueHandle::Unhashable { - value: Either::Left(val), + leaf: Either::Left(val), .. } => Value::Extension { e: Arc::try_unwrap(val).unwrap_or_else(|a| a.as_ref().clone()), }, ValueHandle::Unhashable { - value: Either::Right(hugr), + leaf: Either::Right(hugr), .. } => Value::function(Arc::try_unwrap(hugr).unwrap_or_else(|a| a.as_ref().clone())) .map_err(|e| e.to_string()) From 7040e83d039028fd424e46eb4cd5c7e1f0a7bdb3 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 21:46:02 +0000 Subject: [PATCH 259/281] prepopulate_df_inputs --- hugr-passes/src/dataflow/datalog.rs | 44 ++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index ab8c7c6f3..e71e34896 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -20,10 +20,11 @@ type PV = PartialValue; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] -/// 2. (Optionally) zero or more calls to [Self::prepopulate_wire] with initial values. -/// For example, for a [Module](OpType::Module)-rooted Hugr, each externally-callable -/// [FuncDefn](OpType::FuncDefn) should have the out-wires from its [Input](OpType::Input) -/// node prepopulated with [PartialValue::Top]. +/// 2. (Optionally) zero or more calls to [Self::prepopulate_wire] and/or +/// [Self::prepopulate_df_inputs] with initial values. +/// For example, to analyse a [Module](OpType::Module)-rooted Hugr as a library, +/// [Self::prepopulate_df_inputs] can be used on each externally-callable +/// [FuncDefn](OpType::FuncDefn) to set all inputs to [PartialValue::Top]. /// 3. Call [Self::run] to produce [AnalysisResults] pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); @@ -43,6 +44,26 @@ impl Machine { ); } + /// Provide initial values for the inputs to a [DataflowParent](hugr_core::ops::OpTag::DataflowParent) + /// (that is, values on the wires leaving the [Input](OpType::Input) child thereof). + /// Any out-ports of said same `Input` node, not given values by `in_values`, are set to [PartialValue::Top]. + pub fn prepopulate_df_inputs( + &mut self, + h: &impl HugrView, + parent: Node, + in_values: impl IntoIterator)>, + ) { + // Put values onto out-wires of Input node + let [inp, _] = h.get_io(parent).unwrap(); + let mut vals = vec![PartialValue::Top; h.signature(inp).unwrap().output_types().len()]; + for (ip, v) in in_values { + vals[ip.index()] = v; + } + for (i, v) in vals.into_iter().enumerate() { + self.prepopulate_wire(h, Wire::new(inp, i), v); + } + } + /// Run the analysis (iterate until a lattice fixpoint is reached), /// given initial values for some of the root node inputs. For a /// [Module](OpType::Module)-rooted Hugr, these are input to the function `"main"`. @@ -81,16 +102,11 @@ impl Machine { // (Consider: for a conditional that selects *either* the unknown input *or* value V, // analysis must produce Top == we-know-nothing, not `V` !) if let Some(p) = input_node_parent { - // Put values onto out-wires of Input node - let [inp, _] = context.get_io(p).unwrap(); - let mut vals = - vec![PartialValue::Top; context.signature(inp).unwrap().output_types().len()]; - for (ip, v) in in_values { - vals[ip.index()] = v; - } - for (i, v) in vals.into_iter().enumerate() { - self.prepopulate_wire(&*context, Wire::new(inp, i), v); - } + self.prepopulate_df_inputs( + &*context, + p, + in_values.map(|(p, v)| (OutgoingPort::from(p.index()), v)), + ); } else { // Put values onto in-wires of root node, datalog will do the rest self.0.extend(in_values.map(|(p, v)| (root, p, v))); From 584327f36bfacf6e8e7c455aa84c84fb14a09f01 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 20 Nov 2024 13:52:42 +0000 Subject: [PATCH 260/281] Revert "Use row_contains_bottom for CFG+DFG, and augment unpack_first(=>_no_bottom)" This reverts commit e34c7bef059295be4bc439ec5db35cd4b92eeded. --- hugr-passes/src/dataflow/datalog.rs | 17 ++++++----------- hugr-passes/src/dataflow/value_row.rs | 15 ++++++--------- 2 files changed, 12 insertions(+), 20 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index e71e34896..b814b6440 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -195,10 +195,7 @@ pub(super) fn run_datalog>( dfg_node(n) <-- node(n), if ctx.get_optype(*n).is_dfg(); out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), - input_child(dfg, i), - node_in_value_row(dfg, row), - if !row_contains_bottom(&row[..]), // Treat the DFG as a scheduling barrier - for (p, v) in row[..].iter().enumerate(); + input_child(dfg, i), in_wire_value(dfg, p, v); out_wire_value(dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), output_child(dfg, o), in_wire_value(o, p, v); @@ -217,7 +214,7 @@ pub(super) fn run_datalog>( output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... // ...and select just what's possible for CONTINUE_TAG, if anything - if let Some(fields) = out_in_row.unpack_first_no_bottom(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()), + if let Some(fields) = out_in_row.unpack_first(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()), for (out_p, v) in fields.enumerate(); // Output node of child region propagate to outputs of tail loop @@ -226,7 +223,7 @@ pub(super) fn run_datalog>( output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... // ... and select just what's possible for BREAK_TAG, if anything - if let Some(fields) = out_in_row.unpack_first_no_bottom(TailLoop::BREAK_TAG, tailloop.just_outputs.len()), + if let Some(fields) = out_in_row.unpack_first(TailLoop::BREAK_TAG, tailloop.just_outputs.len()), for (out_p, v) in fields.enumerate(); // Conditional -------------------- @@ -243,7 +240,7 @@ pub(super) fn run_datalog>( input_child(case, i_node), node_in_value_row(cond, in_row), let conditional = ctx.get_optype(*cond).as_conditional().unwrap(), - if let Some(fields) = in_row.unpack_first_no_bottom(*case_index, conditional.sum_rows[*case_index].len()), + if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), for (out_p, v) in fields.enumerate(); // outputs of case nodes propagate to outputs of conditional *if* case reachable @@ -278,9 +275,7 @@ pub(super) fn run_datalog>( cfg_node(cfg), if let Some(entry) = ctx.children(*cfg).next(), input_child(entry, i_node), - node_in_value_row(cfg, row), - if !row_contains_bottom(&row[..]), - for (p, v) in row[..].iter().enumerate(); + in_wire_value(cfg, p, v); // In `CFG` , values fed along a control-flow edge to // come out of Value outports of : @@ -299,7 +294,7 @@ pub(super) fn run_datalog>( output_child(pred, out_n), _cfg_succ_dest(cfg, succ, dest), node_in_value_row(out_n, out_in_row), - if let Some(fields) = out_in_row.unpack_first_no_bottom(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), + if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), for (out_p, v) in fields.enumerate(); // Call -------------------- diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs index 9360f36e3..50cf10318 100644 --- a/hugr-passes/src/dataflow/value_row.rs +++ b/hugr-passes/src/dataflow/value_row.rs @@ -8,7 +8,7 @@ use std::{ use ascent::{lattice::BoundedLattice, Lattice}; use itertools::zip_eq; -use super::{row_contains_bottom, AbstractValue, PartialValue}; +use super::{AbstractValue, PartialValue}; #[derive(PartialEq, Clone, Debug, Eq, Hash)] pub(super) struct ValueRow(Vec>); @@ -27,19 +27,16 @@ impl ValueRow { Self(vec![v]) } - /// If the first value in this ValueRow is a sum, that might contain - /// the specified tag, then unpack the elements of that tag, append the rest - /// of this ValueRow, and if none of the elements of that row [contain bottom](PartialValue::contains_bottom), - /// return it. - /// Otherwise (if no such tag, or values contain bottom), return None. - pub fn unpack_first_no_bottom( + /// The first value in this ValueRow must be a sum; + /// returns a new ValueRow given by unpacking the elements of the specified variant of said first value, + /// then appending the rest of the values in this row. + pub fn unpack_first( &self, variant: usize, len: usize, ) -> Option>> { let vals = self[0].variant_values(variant, len)?; - (!row_contains_bottom(vals.iter().chain(self.0[1..].iter()))) - .then(|| vals.into_iter().chain(self.0[1..].to_owned())) + Some(vals.into_iter().chain(self.0[1..].to_owned())) } } From 409f377878af5c58252515160d53ddf9a39848ba Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 20 Nov 2024 15:23:33 +0000 Subject: [PATCH 261/281] Fix FuncDefn::name, add test_module --- hugr-passes/src/dataflow/datalog.rs | 2 +- hugr-passes/src/dataflow/test.rs | 69 ++++++++++++++++++++++++++++- 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index b814b6440..8dfd98081 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -87,7 +87,7 @@ impl Machine { context .get_optype(*n) .as_func_defn() - .is_some_and(|f| f.name() == "main") + .is_some_and(|f| f.name == "main") }); if main.is_none() && in_values.next().is_some() { panic!("Cannot give inputs to module with no 'main'"); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index c0fbf395a..dafdb8046 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,6 +1,6 @@ use ascent::{lattice::BoundedLattice, Lattice}; -use hugr_core::builder::{CFGBuilder, Container, DataflowHugr}; +use hugr_core::builder::{CFGBuilder, Container, DataflowHugr, ModuleBuilder}; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::handle::DfgID; use hugr_core::ops::TailLoop; @@ -483,3 +483,70 @@ fn test_region() { ); } } + +#[test] +fn test_module() { + let mut modb = ModuleBuilder::new(); + let leaf_fn = modb + .define_function("leaf", Signature::new_endo(type_row![BOOL_T; 2])) + .unwrap(); + let outs = leaf_fn.input_wires(); + let leaf_fn = leaf_fn.finish_with_outputs(outs).unwrap(); + + let mut f2 = modb + .define_function("f2", Signature::new(BOOL_T, type_row![BOOL_T; 2])) + .unwrap(); + let [inp] = f2.input_wires_arr(); + let cst_true = f2.add_load_value(Value::true_val()); + let f2_call = f2 + .call(&leaf_fn.handle(), &[], [inp, cst_true], &EMPTY_REG) + .unwrap(); + let f2 = f2.finish_with_outputs(f2_call.outputs()).unwrap(); + + let mut main = modb + .define_function("main", Signature::new(BOOL_T, type_row![BOOL_T; 2])) + .unwrap(); + let [inp] = main.input_wires_arr(); + let cst_false = main.add_load_value(Value::false_val()); + let main_call = main + .call(&leaf_fn.handle(), &[], [inp, cst_false], &EMPTY_REG) + .unwrap(); + main.finish_with_outputs(main_call.outputs()).unwrap(); + let hugr = modb.finish_hugr(&EMPTY_REG).unwrap(); + let [f2_inp, _] = hugr.get_io(f2.node()).unwrap(); + + let results_just_main = Machine::default().run(TestContext(&hugr), [(0.into(), pv_true())]); + assert_eq!( + results_just_main.read_out_wire(Wire::new(f2_inp, 0)), + Some(PartialValue::Bottom) + ); + for call in [f2_call, main_call] { + // The first output of the Call comes from `main` because no value was fed in from f2 + assert_eq!( + results_just_main.read_out_wire(Wire::new(call.node(), 0)), + Some(pv_true().into()) + ); + // (Without reachability) the second output of the Call is the join of the two constant inputs from the two calls + assert_eq!( + results_just_main.read_out_wire(Wire::new(call.node(), 1)), + Some(pv_true_or_false().into()) + ); + } + + let results_two_calls = { + let mut m = Machine::default(); + m.prepopulate_df_inputs(&hugr, f2.node(), [(0.into(), pv_true())]); + m.run(TestContext(&hugr), [(0.into(), pv_false())]) + }; + + for call in [f2_call, main_call] { + assert_eq!( + results_two_calls.read_out_wire(Wire::new(call.node(), 0)), + Some(pv_true_or_false().into()) + ); + assert_eq!( + results_two_calls.read_out_wire(Wire::new(call.node(), 1)), + Some(pv_true_or_false().into()) + ); + } +} From 39b754ad1806254c2ed2776780651333ff44e5a6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 20 Nov 2024 15:27:58 +0000 Subject: [PATCH 262/281] clippy --- hugr-passes/src/dataflow/datalog.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 8dfd98081..3b0571f95 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -7,7 +7,7 @@ use ascent::lattice::BoundedLattice; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::{NamedOp, OpTrait, OpType, TailLoop}; +use hugr_core::ops::{OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; From c076a2c9ae4f40a614f34722c0620bcb8ef7c52e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 21 Nov 2024 09:10:27 +0000 Subject: [PATCH 263/281] clippy test --- hugr-passes/src/dataflow/test.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index dafdb8046..e5d8d48ee 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -499,7 +499,7 @@ fn test_module() { let [inp] = f2.input_wires_arr(); let cst_true = f2.add_load_value(Value::true_val()); let f2_call = f2 - .call(&leaf_fn.handle(), &[], [inp, cst_true], &EMPTY_REG) + .call(leaf_fn.handle(), &[], [inp, cst_true], &EMPTY_REG) .unwrap(); let f2 = f2.finish_with_outputs(f2_call.outputs()).unwrap(); @@ -509,7 +509,7 @@ fn test_module() { let [inp] = main.input_wires_arr(); let cst_false = main.add_load_value(Value::false_val()); let main_call = main - .call(&leaf_fn.handle(), &[], [inp, cst_false], &EMPTY_REG) + .call(leaf_fn.handle(), &[], [inp, cst_false], &EMPTY_REG) .unwrap(); main.finish_with_outputs(main_call.outputs()).unwrap(); let hugr = modb.finish_hugr(&EMPTY_REG).unwrap(); @@ -524,12 +524,12 @@ fn test_module() { // The first output of the Call comes from `main` because no value was fed in from f2 assert_eq!( results_just_main.read_out_wire(Wire::new(call.node(), 0)), - Some(pv_true().into()) + Some(pv_true()) ); // (Without reachability) the second output of the Call is the join of the two constant inputs from the two calls assert_eq!( results_just_main.read_out_wire(Wire::new(call.node(), 1)), - Some(pv_true_or_false().into()) + Some(pv_true_or_false()) ); } @@ -542,11 +542,11 @@ fn test_module() { for call in [f2_call, main_call] { assert_eq!( results_two_calls.read_out_wire(Wire::new(call.node(), 0)), - Some(pv_true_or_false().into()) + Some(pv_true_or_false()) ); assert_eq!( results_two_calls.read_out_wire(Wire::new(call.node(), 1)), - Some(pv_true_or_false().into()) + Some(pv_true_or_false()) ); } } From 77b6739469c05ce0e49e064aa02d5abee3981ffe Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sat, 23 Nov 2024 09:53:52 +0000 Subject: [PATCH 264/281] try_into_{value=>concrete} --- hugr-passes/src/dataflow/datalog.rs | 5 +++-- hugr-passes/src/dataflow/partial_value.rs | 25 ++++++++++------------- hugr-passes/src/dataflow/results.rs | 4 ++-- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 3b0571f95..2e690001e 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -384,8 +384,9 @@ fn propagate_leaf_op( // Interpret op using DFContext // Default to Top i.e. can't figure out anything about the outputs let mut outs = vec![PartialValue::Top; num_outs]; - // It might be nice to convert `ins`` to [(IncomingPort, Value)], or some concrete value, - // for the context, but PV contains more information, and try_into_value may fail. + // It might be nice to convert `ins` to [(IncomingPort, Value)], or some + // other concrete value, for the context, but PV contains more information, + // and try_into_concrete may fail. ctx.interpret_leaf_op(n, e, ins, &mut outs[..]); outs })) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 60a3ae514..f2a497806 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -147,21 +147,18 @@ impl PartialSum { self.0.contains_key(&tag) } - /// Turns this instance into a [Sum] of some "concrete" value type `V2`, + /// Turns this instance into a [Sum] of some "concrete" value type `C`, /// *if* this PartialSum has exactly one possible tag. /// /// # Errors /// /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] /// supporting the single possible tag with the correct number of elements and no row variables; - /// or if converting a child element failed via [PartialValue::try_into_value]. - pub fn try_into_sum( - self, - typ: &Type, - ) -> Result, ExtractValueError> + /// or if converting a child element failed via [PartialValue::try_into_concrete]. + pub fn try_into_sum(self, typ: &Type) -> Result, ExtractValueError> where - V: TryInto, - Sum: TryInto, + V: TryInto, + Sum: TryInto, { if self.0.len() != 1 { return Err(ExtractValueError::MultipleVariants(self)); @@ -174,7 +171,7 @@ impl PartialSum { return Ok(Sum { tag, values: zip_eq(v, r.iter()) - .map(|(v, t)| v.try_into_value(t)) + .map(|(v, t)| v.try_into_concrete(t)) .collect::, _>>()?, st: st.clone(), }); @@ -198,7 +195,7 @@ impl PartialSum { } /// An error converting a [PartialValue] or [PartialSum] into a concrete value type -/// via [PartialValue::try_into_value] or [PartialSum::try_into_value] +/// via [PartialValue::try_into_concrete] or [PartialSum::try_into_sum] #[derive(Clone, Debug, PartialEq, Eq, Error)] #[allow(missing_docs)] pub enum ExtractValueError { @@ -342,7 +339,7 @@ impl PartialValue { } } - /// Turns this instance into some "concrete" value type `V2`, *if* it is a single value, + /// Turns this instance into some "concrete" value type `C`, *if* it is a single value, /// or a [Sum](PartialValue::PartialSum) (of a single tag) convertible by /// [PartialSum::try_into_sum]. /// @@ -351,10 +348,10 @@ impl PartialValue { /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) /// that could not be converted into a [Sum] by [PartialSum::try_into_sum] (e.g. if `typ` is /// incorrect), or if that [Sum] could not be converted into a `V2`. - pub fn try_into_value(self, typ: &Type) -> Result> + pub fn try_into_concrete(self, typ: &Type) -> Result> where - V: TryInto, - Sum: TryInto, + V: TryInto, + Sum: TryInto, { match self { Self::Value(v) => v diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index 21d6b13c0..50d55f22d 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -82,7 +82,7 @@ impl> AnalysisResults { /// /// # Errors /// `None` if the analysis did not produce a result for that wire - /// `Some(e)` if conversion to a concrete value failed with error `e`, see [PartialValue::try_into_value] + /// `Some(e)` if [conversion to a concrete value](PartialValue::try_into_concrete) failed with error `e` /// /// # Panics /// @@ -100,7 +100,7 @@ impl> AnalysisResults { .out_value_types(w.node()) .find(|(p, _)| *p == w.source()) .unwrap(); - v.try_into_value(&typ).map_err(Some) + v.try_into_concrete(&typ).map_err(Some) } } From caa888218a0e1db21e52e4310d500f0444d1a6cd Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sat, 23 Nov 2024 09:59:11 +0000 Subject: [PATCH 265/281] try_read_wire_{value=>concrete} --- hugr-passes/src/dataflow/results.rs | 11 ++++------- hugr-passes/src/dataflow/test.rs | 14 ++++++++------ 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index 50d55f22d..d60b89e36 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -81,13 +81,10 @@ impl> AnalysisResults { /// [PartialValue::Value] or a [PartialValue::PartialSum] with a single possible tag.) /// /// # Errors - /// `None` if the analysis did not produce a result for that wire + /// `None` if the analysis did not produce a result for that wire, or if + /// the Hugr did not have a [Type](hugr_core::types::Type) for the specified wire /// `Some(e)` if [conversion to a concrete value](PartialValue::try_into_concrete) failed with error `e` - /// - /// # Panics - /// - /// If a [Type](hugr_core::types::Type) for the specified wire could not be extracted from the Hugr - pub fn try_read_wire_value( + pub fn try_read_wire_concrete( &self, w: Wire, ) -> Result>> @@ -99,7 +96,7 @@ impl> AnalysisResults { .hugr() .out_value_types(w.node()) .find(|(p, _)| *p == w.source()) - .unwrap(); + .ok_or(None)?; v.try_into_concrete(&typ).map_err(Some) } } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index e5d8d48ee..b3ea4a04f 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -68,7 +68,7 @@ fn test_make_tuple() { let results = Machine::default().run(TestContext(hugr), []); - let x: Value = results.try_read_wire_value(v3).unwrap(); + let x: Value = results.try_read_wire_concrete(v3).unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); } @@ -84,9 +84,9 @@ fn test_unpack_tuple_const() { let results = Machine::default().run(TestContext(hugr), []); - let o1_r: Value = results.try_read_wire_value(o1).unwrap(); + let o1_r: Value = results.try_read_wire_concrete(o1).unwrap(); assert_eq!(o1_r, Value::false_val()); - let o2_r: Value = results.try_read_wire_value(o2).unwrap(); + let o2_r: Value = results.try_read_wire_concrete(o2).unwrap(); assert_eq!(o2_r, Value::true_val()); } @@ -110,7 +110,7 @@ fn test_tail_loop_never_iterates() { let results = Machine::default().run(TestContext(hugr), []); - let o_r: Value = results.try_read_wire_value(tl_o).unwrap(); + let o_r: Value = results.try_read_wire_concrete(tl_o).unwrap(); assert_eq!(o_r, r_v); assert_eq!( Some(TailLoopTermination::NeverContinues), @@ -298,9 +298,11 @@ fn test_conditional() { )); let results = Machine::default().run(TestContext(hugr), [(0.into(), arg_pv)]); - let cond_r1: Value = results.try_read_wire_value(cond_o1).unwrap(); + let cond_r1: Value = results.try_read_wire_concrete(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); - assert!(results.try_read_wire_value::(cond_o2).is_err()); + assert!(results + .try_read_wire_concrete::(cond_o2) + .is_err()); assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only assert_eq!(results.case_reachable(case2.node()), Some(true)); From d387cef28989b08def782b4e1392a9932b8c3613 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 4 Dec 2024 10:19:44 +0000 Subject: [PATCH 266/281] fix BOOL_T -> bool_t() and types needing extensions --- hugr-passes/src/dataflow/test.rs | 59 +++++++++++++++++--------------- 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index b3ea4a04f..7ea8c6eb0 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,13 +1,15 @@ use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::builder::{CFGBuilder, Container, DataflowHugr, ModuleBuilder}; +use hugr_core::extension::PRELUDE_REGISTRY; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::handle::DfgID; use hugr_core::ops::TailLoop; +use hugr_core::types::TypeRow; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ - prelude::{UnpackTuple, BOOL_T}, + prelude::{bool_t, UnpackTuple}, ExtensionSet, EMPTY_REG, }, ops::{handle::NodeHandle, DataflowOpTrait, Tag, Value}, @@ -64,7 +66,7 @@ fn test_make_tuple() { let v1 = builder.add_load_value(Value::false_val()); let v2 = builder.add_load_value(Value::true_val()); let v3 = builder.make_tuple([v1, v2]).unwrap(); - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap(); let results = Machine::default().run(TestContext(hugr), []); @@ -77,10 +79,10 @@ fn test_unpack_tuple_const() { let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); let v = builder.add_load_value(Value::tuple([Value::false_val(), Value::true_val()])); let [o1, o2] = builder - .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T, BOOL_T]), [v]) + .add_dataflow_op(UnpackTuple::new(vec![bool_t(); 2].into()), [v]) .unwrap() .outputs_arr(); - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap(); let results = Machine::default().run(TestContext(hugr), []); @@ -125,14 +127,14 @@ fn test_tail_loop_always_iterates() { Value::sum( TailLoop::CONTINUE_TAG, [], - SumType::new([type_row![], BOOL_T.into()]), + SumType::new([type_row![], bool_t().into()]), ) .unwrap(), ); let true_w = builder.add_load_value(Value::true_val()); let tlb = builder - .tail_loop_builder([], [(BOOL_T, true_w)], vec![BOOL_T].into()) + .tail_loop_builder([], [(bool_t(), true_w)], vec![bool_t()].into()) .unwrap(); // r_w has tag 0, so we always continue; @@ -166,14 +168,14 @@ fn test_tail_loop_two_iters() { let tlb = builder .tail_loop_builder_exts( [], - [(BOOL_T, false_w), (BOOL_T, true_w)], + [(bool_t(), false_w), (bool_t(), true_w)], type_row![], ExtensionSet::new(), ) .unwrap(); assert_eq!( tlb.loop_signature().unwrap().signature(), - Signature::new_endo(type_row![BOOL_T, BOOL_T]) + Signature::new_endo(vec![bool_t(); 2]) ); let [in_w1, in_w2] = tlb.input_wires_arr(); let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap(); @@ -197,9 +199,9 @@ fn test_tail_loop_two_iters() { #[test] fn test_tail_loop_containing_conditional() { let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); - let control_variants = vec![type_row![BOOL_T;2]; 2]; + let control_variants = vec![vec![bool_t(); 2].into(); 2]; let control_t = Type::new_sum(control_variants.clone()); - let body_out_variants = vec![control_t.clone().into(), type_row![BOOL_T; 2]]; + let body_out_variants = vec![TypeRow::from(control_t.clone()), vec![bool_t(); 2].into()]; let init = builder.add_load_value( Value::sum( @@ -211,7 +213,7 @@ fn test_tail_loop_containing_conditional() { ); let mut tlb = builder - .tail_loop_builder([(control_t, init)], [], type_row![BOOL_T; 2]) + .tail_loop_builder([(control_t, init)], [], vec![bool_t(); 2].into()) .unwrap(); let tl = tlb.loop_signature().unwrap().clone(); let [in_w] = tlb.input_wires_arr(); @@ -259,7 +261,7 @@ fn test_tail_loop_containing_conditional() { #[test] fn test_conditional() { - let variants = vec![type_row![], type_row![], type_row![BOOL_T]]; + let variants = vec![type_row![], type_row![], bool_t().into()]; let cond_t = Type::new_sum(variants.clone()); let mut builder = DFGBuilder::new(Signature::new(cond_t, type_row![])).unwrap(); let [arg_w] = builder.input_wires_arr(); @@ -270,8 +272,8 @@ fn test_conditional() { let mut cond_builder = builder .conditional_builder( (variants, arg_w), - [(BOOL_T, true_w)], - type_row!(BOOL_T, BOOL_T), + [(bool_t(), true_w)], + vec![bool_t(); 2].into(), ) .unwrap(); // will be unreachable @@ -325,11 +327,11 @@ fn xor_and_cfg() -> Hugr { // T,F T,F - T,F // T,T T,T T,F F,T let mut builder = - CFGBuilder::new(Signature::new(type_row![BOOL_T; 2], type_row![BOOL_T; 2])).unwrap(); + CFGBuilder::new(Signature::new(vec![bool_t(); 2], vec![bool_t(); 2])).unwrap(); // entry (x, y) => (if x then A else B)(x=true, y) let entry = builder - .entry_builder(vec![type_row![]; 2], type_row![BOOL_T;2]) + .entry_builder(vec![type_row![]; 2], vec![bool_t(); 2].into()) .unwrap(); let [in_x, in_y] = entry.input_wires_arr(); let entry = entry.finish_with_outputs(in_x, [in_x, in_y]).unwrap(); @@ -337,9 +339,9 @@ fn xor_and_cfg() -> Hugr { // A(x==true, y) => (if y then B else X)(x, false) let mut a = builder .block_builder( - type_row![BOOL_T; 2], + vec![bool_t(); 2].into(), vec![type_row![]; 2], - type_row![BOOL_T; 2], + vec![bool_t(); 2].into(), ) .unwrap(); let [in_x, in_y] = a.input_wires_arr(); @@ -348,7 +350,11 @@ fn xor_and_cfg() -> Hugr { // B(w, v) => X(v, w) let mut b = builder - .block_builder(type_row![BOOL_T; 2], [type_row![]], type_row![BOOL_T; 2]) + .block_builder( + vec![bool_t(); 2].into(), + [type_row![]], + vec![bool_t(); 2].into(), + ) .unwrap(); let [in_w, in_v] = b.input_wires_arr(); let [control] = b @@ -407,9 +413,9 @@ fn test_call( #[case] inp1: PartialValue, #[case] out: PartialValue, ) { - let mut builder = DFGBuilder::new(Signature::new_endo(type_row![BOOL_T; 2])).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(vec![bool_t(); 2])).unwrap(); let func_bldr = builder - .define_function("id", Signature::new_endo(BOOL_T)) + .define_function("id", Signature::new_endo(bool_t())) .unwrap(); let [v] = func_bldr.input_wires_arr(); let func_defn = func_bldr.finish_with_outputs([v]).unwrap(); @@ -436,12 +442,11 @@ fn test_call( #[test] fn test_region() { - let mut builder = - DFGBuilder::new(Signature::new(type_row![BOOL_T], type_row![BOOL_T;2])).unwrap(); + let mut builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(); 2])).unwrap(); let [in_w] = builder.input_wires_arr(); let cst_w = builder.add_load_const(Value::false_val()); let nested = builder - .dfg_builder(Signature::new_endo(type_row![BOOL_T; 2]), [in_w, cst_w]) + .dfg_builder(Signature::new_endo(vec![bool_t(); 2]), [in_w, cst_w]) .unwrap(); let nested_ins = nested.input_wires(); let nested = nested.finish_with_outputs(nested_ins).unwrap(); @@ -490,13 +495,13 @@ fn test_region() { fn test_module() { let mut modb = ModuleBuilder::new(); let leaf_fn = modb - .define_function("leaf", Signature::new_endo(type_row![BOOL_T; 2])) + .define_function("leaf", Signature::new_endo(vec![bool_t(); 2])) .unwrap(); let outs = leaf_fn.input_wires(); let leaf_fn = leaf_fn.finish_with_outputs(outs).unwrap(); let mut f2 = modb - .define_function("f2", Signature::new(BOOL_T, type_row![BOOL_T; 2])) + .define_function("f2", Signature::new(bool_t(), vec![bool_t(); 2])) .unwrap(); let [inp] = f2.input_wires_arr(); let cst_true = f2.add_load_value(Value::true_val()); @@ -506,7 +511,7 @@ fn test_module() { let f2 = f2.finish_with_outputs(f2_call.outputs()).unwrap(); let mut main = modb - .define_function("main", Signature::new(BOOL_T, type_row![BOOL_T; 2])) + .define_function("main", Signature::new(bool_t(), vec![bool_t(); 2])) .unwrap(); let [inp] = main.input_wires_arr(); let cst_false = main.add_load_value(Value::false_val()); From df434e7be0e790f802f396b8de2665defb6da690 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 4 Dec 2024 10:31:45 +0000 Subject: [PATCH 267/281] clippy --- hugr-passes/src/dataflow.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index f6e710f66..f9742953a 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -52,7 +52,7 @@ pub enum ConstLocation<'a> { Node(Node), } -impl<'a> From for ConstLocation<'a> { +impl From for ConstLocation<'_> { fn from(value: Node) -> Self { ConstLocation::Node(value) } From 650cdfd2ef2e020839d835b95f2abac55bee46b0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 26 Nov 2024 10:22:33 +0000 Subject: [PATCH 268/281] And separate DFContext from HugrView once again --- hugr-passes/src/dataflow.rs | 8 +-- hugr-passes/src/dataflow/datalog.rs | 102 ++++++++++++++-------------- hugr-passes/src/dataflow/results.rs | 31 ++++----- hugr-passes/src/dataflow/test.rs | 41 +++++------ 4 files changed, 84 insertions(+), 98 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index f9742953a..5faf3d733 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -14,15 +14,11 @@ pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, Value}; use hugr_core::types::TypeArg; -use hugr_core::{Hugr, HugrView, Node}; +use hugr_core::{Hugr, Node}; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). -pub trait DFContext: ConstLoader + std::ops::Deref { - /// Type of view contained within this context. (Ideally we'd constrain - /// by `std::ops::Deref` but that's not stable yet.) - type View: HugrView; - +pub trait DFContext: ConstLoader { /// Given lattice values for each input, update lattice values for the (dataflow) outputs. /// For extension ops only, excluding [MakeTuple] and [UnpackTuple]. /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 2e690001e..74e4a0b9e 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -72,20 +72,20 @@ impl Machine { /// # Panics /// May panic in various ways if the Hugr is invalid; /// or if any `in_values` are provided for a module-rooted Hugr without a function `"main"`. - pub fn run>( + pub fn run( mut self, - context: C, + context: &impl DFContext, + hugr: H, in_values: impl IntoIterator)>, - ) -> AnalysisResults { + ) -> AnalysisResults { let mut in_values = in_values.into_iter(); - let root = context.root(); + let root = hugr.root(); // Some nodes do not accept values as dataflow inputs - for these // we must find the corresponding Input node. - let input_node_parent = match context.get_optype(root) { + let input_node_parent = match hugr.get_optype(root) { OpType::Module(_) => { - let main = context.children(root).find(|n| { - context - .get_optype(*n) + let main = hugr.children(root).find(|n| { + hugr.get_optype(*n) .as_func_defn() .is_some_and(|f| f.name == "main") }); @@ -103,7 +103,7 @@ impl Machine { // analysis must produce Top == we-know-nothing, not `V` !) if let Some(p) = input_node_parent { self.prepopulate_df_inputs( - &*context, + &hugr, p, in_values.map(|(p, v)| (OutgoingPort::from(p.index()), v)), ); @@ -115,7 +115,7 @@ impl Machine { .iter() .filter_map(|(n, p, _)| (n == &root).then_some(*p)) .collect(); - for p in context.signature(root).unwrap_or_default().input_ports() { + for p in hugr.signature(root).unwrap_or_default().input_ports() { if !got_inputs.contains(&p) { self.0.push((root, p, PartialValue::Top)); } @@ -123,14 +123,15 @@ impl Machine { } // Note/TODO, if analysis is running on a subregion then we should do similar // for any nonlocal edges providing values from outside the region. - run_datalog(context, self.0) + run_datalog(context, hugr, self.0) } } -pub(super) fn run_datalog>( - ctx: C, +pub(super) fn run_datalog( + ctx: &impl DFContext, + hugr: H, in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, -) -> AnalysisResults { +) -> AnalysisResults { // ascent-(macro-)generated code generates a bunch of warnings, // keep code in here to a minimum. #![allow( @@ -150,49 +151,49 @@ pub(super) fn run_datalog>( lattice in_wire_value(Node, IncomingPort, PV); // receives, on , the value lattice node_in_value_row(Node, ValueRow); // 's inputs are - node(n) <-- for n in ctx.nodes(); + node(n) <-- for n in hugr.nodes(); - in_wire(n, p) <-- node(n), for (p,_) in ctx.in_value_types(*n); // Note, gets connected inports only - out_wire(n, p) <-- node(n), for (p,_) in ctx.out_value_types(*n); // (and likewise) + in_wire(n, p) <-- node(n), for (p,_) in hugr.in_value_types(*n); // Note, gets connected inports only + out_wire(n, p) <-- node(n), for (p,_) in hugr.out_value_types(*n); // (and likewise) parent_of_node(parent, child) <-- - node(child), if let Some(parent) = ctx.get_parent(*child); + node(child), if let Some(parent) = hugr.get_parent(*child); - input_child(parent, input) <-- node(parent), if let Some([input, _output]) = ctx.get_io(*parent); - output_child(parent, output) <-- node(parent), if let Some([_input, output]) = ctx.get_io(*parent); + input_child(parent, input) <-- node(parent), if let Some([input, _output]) = hugr.get_io(*parent); + output_child(parent, output) <-- node(parent), if let Some([_input, output]) = hugr.get_io(*parent); // Initialize all wires to bottom out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p); // Outputs to inputs in_wire_value(n, ip, v) <-- in_wire(n, ip), - if let Some((m, op)) = ctx.single_linked_output(*n, *ip), + if let Some((m, op)) = hugr.single_linked_output(*n, *ip), out_wire_value(m, op, v); // Prepopulate in_wire_value from in_wire_value_proto. in_wire_value(n, p, PV::bottom()) <-- in_wire(n, p); in_wire_value(n, p, v) <-- for (n, p, v) in in_wire_value_proto.iter(), node(n), - if let Some(sig) = ctx.signature(*n), + if let Some(sig) = hugr.signature(*n), if sig.input_ports().contains(p); // Assemble node_in_value_row from in_wire_value's - node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = ctx.signature(*n); - node_in_value_row(n, ValueRow::new(ctx.signature(*n).unwrap().input_count()).set(p.index(), v.clone())) <-- in_wire_value(n, p, v); + node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = hugr.signature(*n); + node_in_value_row(n, ValueRow::new(hugr.signature(*n).unwrap().input_count()).set(p.index(), v.clone())) <-- in_wire_value(n, p, v); // Interpret leaf ops out_wire_value(n, p, v) <-- node(n), - let op_t = ctx.get_optype(*n), + let op_t = hugr.get_optype(*n), if !op_t.is_container(), if let Some(sig) = op_t.dataflow_signature(), node_in_value_row(n, vs), - if let Some(outs) = propagate_leaf_op(&ctx, *n, &vs[..], sig.output_count()), + if let Some(outs) = propagate_leaf_op(ctx, &hugr, *n, &vs[..], sig.output_count()), for (p, v) in (0..).map(OutgoingPort::from).zip(outs); // DFG -------------------- relation dfg_node(Node); // is a `DFG` - dfg_node(n) <-- node(n), if ctx.get_optype(*n).is_dfg(); + dfg_node(n) <-- node(n), if hugr.get_optype(*n).is_dfg(); out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), input_child(dfg, i), in_wire_value(dfg, p, v); @@ -203,13 +204,13 @@ pub(super) fn run_datalog>( // TailLoop -------------------- // inputs of tail loop propagate to Input node of child region out_wire_value(i, OutgoingPort::from(p.index()), v) <-- node(tl), - if ctx.get_optype(*tl).is_tail_loop(), + if hugr.get_optype(*tl).is_tail_loop(), input_child(tl, i), in_wire_value(tl, p, v); // Output node of child region propagate to Input node of child region out_wire_value(in_n, OutgoingPort::from(out_p), v) <-- node(tl), - if let Some(tailloop) = ctx.get_optype(*tl).as_tail_loop(), + if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), input_child(tl, in_n), output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... @@ -219,7 +220,7 @@ pub(super) fn run_datalog>( // Output node of child region propagate to outputs of tail loop out_wire_value(tl, OutgoingPort::from(out_p), v) <-- node(tl), - if let Some(tailloop) = ctx.get_optype(*tl).as_tail_loop(), + if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... // ... and select just what's possible for BREAK_TAG, if anything @@ -230,16 +231,16 @@ pub(super) fn run_datalog>( // is a `Conditional` and its 'th child (a `Case`) is : relation case_node(Node, usize, Node); case_node(cond, i, case) <-- node(cond), - if ctx.get_optype(*cond).is_conditional(), - for (i, case) in ctx.children(*cond).enumerate(), - if ctx.get_optype(case).is_case(); + if hugr.get_optype(*cond).is_conditional(), + for (i, case) in hugr.children(*cond).enumerate(), + if hugr.get_optype(case).is_case(); // inputs of conditional propagate into case nodes out_wire_value(i_node, OutgoingPort::from(out_p), v) <-- case_node(cond, case_index, case), input_child(case, i_node), node_in_value_row(cond, in_row), - let conditional = ctx.get_optype(*cond).as_conditional().unwrap(), + let conditional = hugr.get_optype(*cond).as_conditional().unwrap(), if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), for (out_p, v) in fields.enumerate(); @@ -258,39 +259,39 @@ pub(super) fn run_datalog>( // CFG -------------------- relation cfg_node(Node); // is a `CFG` - cfg_node(n) <-- node(n), if ctx.get_optype(*n).is_cfg(); + cfg_node(n) <-- node(n), if hugr.get_optype(*n).is_cfg(); // In `CFG` , basic block is reachable given our knowledge of predicates: relation bb_reachable(Node, Node); - bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = ctx.children(*cfg).next(); + bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = hugr.children(*cfg).next(); bb_reachable(cfg, bb) <-- cfg_node(cfg), bb_reachable(cfg, pred), output_child(pred, pred_out), in_wire_value(pred_out, IncomingPort::from(0), predicate), - for (tag, bb) in ctx.output_neighbours(*pred).enumerate(), + for (tag, bb) in hugr.output_neighbours(*pred).enumerate(), if predicate.supports_tag(tag); // Inputs of CFG propagate to entry block out_wire_value(i_node, OutgoingPort::from(p.index()), v) <-- cfg_node(cfg), - if let Some(entry) = ctx.children(*cfg).next(), + if let Some(entry) = hugr.children(*cfg).next(), input_child(entry, i_node), in_wire_value(cfg, p, v); // In `CFG` , values fed along a control-flow edge to // come out of Value outports of : relation _cfg_succ_dest(Node, Node, Node); - _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = ctx.children(*cfg).nth(1); + _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = hugr.children(*cfg).nth(1); _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), - for blk in ctx.children(*cfg), - if ctx.get_optype(blk).is_dataflow_block(), + for blk in hugr.children(*cfg), + if hugr.get_optype(blk).is_dataflow_block(), input_child(blk, inp); // Outputs of each reachable block propagated to successor block or CFG itself out_wire_value(dest, OutgoingPort::from(out_p), v) <-- bb_reachable(cfg, pred), - if let Some(df_block) = ctx.get_optype(*pred).as_dataflow_block(), - for (succ_n, succ) in ctx.output_neighbours(*pred).enumerate(), + if let Some(df_block) = hugr.get_optype(*pred).as_dataflow_block(), + for (succ_n, succ) in hugr.output_neighbours(*pred).enumerate(), output_child(pred, out_n), _cfg_succ_dest(cfg, succ, dest), node_in_value_row(out_n, out_in_row), @@ -301,8 +302,8 @@ pub(super) fn run_datalog>( relation func_call(Node, Node); // is a `Call` to `FuncDefn` func_call(call, func_defn) <-- node(call), - if ctx.get_optype(*call).is_call(), - if let Some(func_defn) = ctx.static_source(*call); + if hugr.get_optype(*call).is_call(), + if let Some(func_defn) = hugr.static_source(*call); out_wire_value(inp, OutgoingPort::from(p.index()), v) <-- func_call(call, func), @@ -320,7 +321,7 @@ pub(super) fn run_datalog>( .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) .collect(); AnalysisResults { - ctx, + hugr, out_wire_values, in_wire_value: all_results.in_wire_value, case_reachable: all_results.case_reachable, @@ -330,11 +331,12 @@ pub(super) fn run_datalog>( fn propagate_leaf_op( ctx: &impl DFContext, + hugr: &impl HugrView, n: Node, ins: &[PV], num_outs: usize, ) -> Option> { - match ctx.get_optype(n) { + match hugr.get_optype(n) { // Handle basics here. We could instead leave these to DFContext, // but at least we'd want these impls to be easily reusable. op if op.cast::().is_some() => Some(ValueRow::from_iter([PV::new_variant( @@ -356,16 +358,16 @@ fn propagate_leaf_op( OpType::Const(_) => None, // handled by LoadConstant: OpType::LoadConstant(load_op) => { assert!(ins.is_empty()); // static edge, so need to find constant - let const_node = ctx + let const_node = hugr .single_linked_output(n, load_op.constant_port()) .unwrap() .0; - let const_val = ctx.get_optype(const_node).as_const().unwrap().value(); + let const_val = hugr.get_optype(const_node).as_const().unwrap().value(); Some(ValueRow::singleton(partial_from_const(ctx, n, const_val))) } OpType::LoadFunction(load_op) => { assert!(ins.is_empty()); // static edge - let func_node = ctx + let func_node = hugr .single_linked_output(n, load_op.function_port()) .unwrap() .0; diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index d60b89e36..0f4704b42 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -2,24 +2,19 @@ use std::collections::HashMap; use hugr_core::{HugrView, IncomingPort, Node, PortIndex, Wire}; -use super::{partial_value::ExtractValueError, AbstractValue, DFContext, PartialValue, Sum}; +use super::{partial_value::ExtractValueError, AbstractValue, PartialValue, Sum}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). -pub struct AnalysisResults> { - pub(super) ctx: C, +pub struct AnalysisResults { + pub(super) hugr: H, pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, pub(super) case_reachable: Vec<(Node, Node)>, pub(super) bb_reachable: Vec<(Node, Node)>, pub(super) out_wire_values: HashMap>, } -impl> AnalysisResults { - /// Allows to use the [HugrView] contained within - pub fn hugr(&self) -> &C::View { - &self.ctx - } - +impl AnalysisResults { /// Gets the lattice value computed for the given wire pub fn read_out_wire(&self, w: Wire) -> Option> { self.out_wire_values.get(&w).cloned() @@ -31,8 +26,8 @@ impl> AnalysisResults { /// /// [TailLoop]: hugr_core::ops::TailLoop pub fn tail_loop_terminates(&self, node: Node) -> Option { - self.hugr().get_optype(node).as_tail_loop()?; - let [_, out] = self.hugr().get_io(node).unwrap(); + self.hugr.get_optype(node).as_tail_loop()?; + let [_, out] = self.hugr.get_io(node).unwrap(); Some(TailLoopTermination::from_control_value( self.in_wire_value .iter() @@ -49,9 +44,9 @@ impl> AnalysisResults { /// [Case]: hugr_core::ops::Case /// [Conditional]: hugr_core::ops::Conditional pub fn case_reachable(&self, case: Node) -> Option { - self.hugr().get_optype(case).as_case()?; - let cond = self.hugr().get_parent(case)?; - self.hugr().get_optype(cond).as_conditional()?; + self.hugr.get_optype(case).as_case()?; + let cond = self.hugr.get_parent(case)?; + self.hugr.get_optype(cond).as_conditional()?; Some( self.case_reachable .iter() @@ -66,9 +61,9 @@ impl> AnalysisResults { /// [DataflowBlock]: hugr_core::ops::DataflowBlock /// [ExitBlock]: hugr_core::ops::ExitBlock pub fn bb_reachable(&self, bb: Node) -> Option { - let cfg = self.hugr().get_parent(bb)?; // Not really required...?? - self.hugr().get_optype(cfg).as_cfg()?; - let t = self.hugr().get_optype(bb); + let cfg = self.hugr.get_parent(bb)?; // Not really required...?? + self.hugr.get_optype(cfg).as_cfg()?; + let t = self.hugr.get_optype(bb); (t.is_dataflow_block() || t.is_exit_block()).then(|| { self.bb_reachable .iter() @@ -93,7 +88,7 @@ impl> AnalysisResults { { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self - .hugr() + .hugr .out_value_types(w.node()) .find(|(p, _)| *p == w.source()) .ok_or(None)?; diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 7ea8c6eb0..7889b364b 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -28,18 +28,10 @@ enum Void {} impl AbstractValue for Void {} -struct TestContext(H); +struct TestContext; -impl std::ops::Deref for TestContext { - type Target = H; - fn deref(&self) -> &H { - &self.0 - } -} -impl ConstLoader for TestContext {} -impl DFContext for TestContext { - type View = H; -} +impl ConstLoader for TestContext {} +impl DFContext for TestContext {} // This allows testing creation of tuple/sum Values (only) impl From for Value { @@ -68,7 +60,7 @@ fn test_make_tuple() { let v3 = builder.make_tuple([v1, v2]).unwrap(); let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap(); - let results = Machine::default().run(TestContext(hugr), []); + let results = Machine::default().run(&TestContext, &hugr, []); let x: Value = results.try_read_wire_concrete(v3).unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); @@ -84,7 +76,7 @@ fn test_unpack_tuple_const() { .outputs_arr(); let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap(); - let results = Machine::default().run(TestContext(hugr), []); + let results = Machine::default().run(&TestContext, &hugr, []); let o1_r: Value = results.try_read_wire_concrete(o1).unwrap(); assert_eq!(o1_r, Value::false_val()); @@ -110,7 +102,7 @@ fn test_tail_loop_never_iterates() { let [tl_o] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(TestContext(hugr), []); + let results = Machine::default().run(&TestContext, &hugr, []); let o_r: Value = results.try_read_wire_concrete(tl_o).unwrap(); assert_eq!(o_r, r_v); @@ -145,7 +137,7 @@ fn test_tail_loop_always_iterates() { let [tl_o1, tl_o2] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(TestContext(&hugr), []); + let results = Machine::default().run(&TestContext, &hugr, []); let o_r1 = results.read_out_wire(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom()); @@ -183,7 +175,7 @@ fn test_tail_loop_two_iters() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let results = Machine::default().run(TestContext(&hugr), []); + let results = Machine::default().run(&TestContext, &hugr, []); let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true_or_false()); @@ -246,7 +238,7 @@ fn test_tail_loop_containing_conditional() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let results = Machine::default().run(TestContext(&hugr), []); + let results = Machine::default().run(&TestContext, &hugr, []); let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true()); @@ -298,7 +290,7 @@ fn test_conditional() { 2, [PartialValue::new_variant(0, [])], )); - let results = Machine::default().run(TestContext(hugr), [(0.into(), arg_pv)]); + let results = Machine::default().run(&TestContext, &hugr, [(0.into(), arg_pv)]); let cond_r1: Value = results.try_read_wire_concrete(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); @@ -396,7 +388,8 @@ fn test_cfg( ) { let root = xor_and_cfg.root(); let results = Machine::default().run( - TestContext(xor_and_cfg), + &TestContext, + &xor_and_cfg, [(0.into(), inp0), (1.into(), inp1)], ); @@ -432,7 +425,7 @@ fn test_call( .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) .unwrap(); - let results = Machine::default().run(TestContext(&hugr), [(0.into(), inp0), (1.into(), inp1)]); + let results = Machine::default().run(&TestContext, &hugr, [(0.into(), inp0), (1.into(), inp1)]); let [res0, res1] = [0, 1].map(|i| results.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); // The two calls alias so both results will be the same: @@ -454,7 +447,7 @@ fn test_region() { .finish_prelude_hugr_with_outputs(nested.outputs()) .unwrap(); let [nested_input, _] = hugr.get_io(nested.node()).unwrap(); - let whole_hugr_results = Machine::default().run(TestContext(&hugr), [(0.into(), pv_true())]); + let whole_hugr_results = Machine::default().run(&TestContext, &hugr, [(0.into(), pv_true())]); assert_eq!( whole_hugr_results.read_out_wire(Wire::new(nested_input, 0)), Some(pv_true()) @@ -474,7 +467,7 @@ fn test_region() { let subview = DescendantsGraph::::try_new(&hugr, nested.node()).unwrap(); // Do not provide a value on the second input (constant false in the whole hugr, above) - let sub_hugr_results = Machine::default().run(TestContext(subview), [(0.into(), pv_true())]); + let sub_hugr_results = Machine::default().run(&TestContext, subview, [(0.into(), pv_true())]); assert_eq!( sub_hugr_results.read_out_wire(Wire::new(nested_input, 0)), Some(pv_true()) @@ -522,7 +515,7 @@ fn test_module() { let hugr = modb.finish_hugr(&EMPTY_REG).unwrap(); let [f2_inp, _] = hugr.get_io(f2.node()).unwrap(); - let results_just_main = Machine::default().run(TestContext(&hugr), [(0.into(), pv_true())]); + let results_just_main = Machine::default().run(&TestContext, &hugr, [(0.into(), pv_true())]); assert_eq!( results_just_main.read_out_wire(Wire::new(f2_inp, 0)), Some(PartialValue::Bottom) @@ -543,7 +536,7 @@ fn test_module() { let results_two_calls = { let mut m = Machine::default(); m.prepopulate_df_inputs(&hugr, f2.node(), [(0.into(), pv_true())]); - m.run(TestContext(&hugr), [(0.into(), pv_false())]) + m.run(&TestContext, &hugr, [(0.into(), pv_false())]) }; for call in [f2_call, main_call] { From 071c7dd3ebf9dfe7cfa683febf333009f2e7d2d0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 26 Nov 2024 10:29:58 +0000 Subject: [PATCH 269/281] Store HugrView (not DFContext) in Machine --- hugr-passes/src/dataflow/datalog.rs | 51 ++++++++++++++--------------- hugr-passes/src/dataflow/test.rs | 35 +++++++++----------- 2 files changed, 41 insertions(+), 45 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 74e4a0b9e..74dfc68ac 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -19,27 +19,28 @@ use super::{ type PV = PartialValue; /// Basic structure for performing an analysis. Usage: -/// 1. Get a new instance via [Self::default()] +/// 1. Make a new instance via [Self::new()] /// 2. (Optionally) zero or more calls to [Self::prepopulate_wire] and/or /// [Self::prepopulate_df_inputs] with initial values. /// For example, to analyse a [Module](OpType::Module)-rooted Hugr as a library, /// [Self::prepopulate_df_inputs] can be used on each externally-callable /// [FuncDefn](OpType::FuncDefn) to set all inputs to [PartialValue::Top]. /// 3. Call [Self::run] to produce [AnalysisResults] -pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); +pub struct Machine(H, Vec<(Node, IncomingPort, PartialValue)>); -/// derived-Default requires the context to be Defaultable, which is unnecessary -impl Default for Machine { - fn default() -> Self { - Self(Default::default()) +impl Machine { + /// Create a new Machine to analyse the given Hugr(View) + pub fn new(hugr: H) -> Self { + Self(hugr, Default::default()) } } -impl Machine { +impl Machine { /// Provide initial values for a wire - these will be `join`d with any computed. - pub fn prepopulate_wire(&mut self, h: &impl HugrView, w: Wire, v: PartialValue) { - self.0.extend( - h.linked_inputs(w.node(), w.source()) + pub fn prepopulate_wire(&mut self, w: Wire, v: PartialValue) { + self.1.extend( + self.0 + .linked_inputs(w.node(), w.source()) .map(|(n, inp)| (n, inp, v.clone())), ); } @@ -49,18 +50,17 @@ impl Machine { /// Any out-ports of said same `Input` node, not given values by `in_values`, are set to [PartialValue::Top]. pub fn prepopulate_df_inputs( &mut self, - h: &impl HugrView, parent: Node, in_values: impl IntoIterator)>, ) { // Put values onto out-wires of Input node - let [inp, _] = h.get_io(parent).unwrap(); - let mut vals = vec![PartialValue::Top; h.signature(inp).unwrap().output_types().len()]; + let [inp, _] = self.0.get_io(parent).unwrap(); + let mut vals = vec![PartialValue::Top; self.0.signature(inp).unwrap().output_types().len()]; for (ip, v) in in_values { vals[ip.index()] = v; } for (i, v) in vals.into_iter().enumerate() { - self.prepopulate_wire(h, Wire::new(inp, i), v); + self.prepopulate_wire(Wire::new(inp, i), v); } } @@ -72,20 +72,20 @@ impl Machine { /// # Panics /// May panic in various ways if the Hugr is invalid; /// or if any `in_values` are provided for a module-rooted Hugr without a function `"main"`. - pub fn run( + pub fn run( mut self, context: &impl DFContext, - hugr: H, in_values: impl IntoIterator)>, ) -> AnalysisResults { let mut in_values = in_values.into_iter(); - let root = hugr.root(); + let root = self.0.root(); // Some nodes do not accept values as dataflow inputs - for these // we must find the corresponding Input node. - let input_node_parent = match hugr.get_optype(root) { + let input_node_parent = match self.0.get_optype(root) { OpType::Module(_) => { - let main = hugr.children(root).find(|n| { - hugr.get_optype(*n) + let main = self.0.children(root).find(|n| { + self.0 + .get_optype(*n) .as_func_defn() .is_some_and(|f| f.name == "main") }); @@ -103,27 +103,26 @@ impl Machine { // analysis must produce Top == we-know-nothing, not `V` !) if let Some(p) = input_node_parent { self.prepopulate_df_inputs( - &hugr, p, in_values.map(|(p, v)| (OutgoingPort::from(p.index()), v)), ); } else { // Put values onto in-wires of root node, datalog will do the rest - self.0.extend(in_values.map(|(p, v)| (root, p, v))); + self.1.extend(in_values.map(|(p, v)| (root, p, v))); let got_inputs: HashSet<_, RandomState> = self - .0 + .1 .iter() .filter_map(|(n, p, _)| (n == &root).then_some(*p)) .collect(); - for p in hugr.signature(root).unwrap_or_default().input_ports() { + for p in self.0.signature(root).unwrap_or_default().input_ports() { if !got_inputs.contains(&p) { - self.0.push((root, p, PartialValue::Top)); + self.1.push((root, p, PartialValue::Top)); } } } // Note/TODO, if analysis is running on a subregion then we should do similar // for any nonlocal edges providing values from outside the region. - run_datalog(context, hugr, self.0) + run_datalog(context, self.0, self.1) } } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 7889b364b..8f9d18bf9 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -60,7 +60,7 @@ fn test_make_tuple() { let v3 = builder.make_tuple([v1, v2]).unwrap(); let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::new(&hugr).run(&TestContext, []); let x: Value = results.try_read_wire_concrete(v3).unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); @@ -76,7 +76,7 @@ fn test_unpack_tuple_const() { .outputs_arr(); let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::new(&hugr).run(&TestContext, []); let o1_r: Value = results.try_read_wire_concrete(o1).unwrap(); assert_eq!(o1_r, Value::false_val()); @@ -102,7 +102,7 @@ fn test_tail_loop_never_iterates() { let [tl_o] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::new(&hugr).run(&TestContext, []); let o_r: Value = results.try_read_wire_concrete(tl_o).unwrap(); assert_eq!(o_r, r_v); @@ -137,7 +137,7 @@ fn test_tail_loop_always_iterates() { let [tl_o1, tl_o2] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::new(&hugr).run(&TestContext, []); let o_r1 = results.read_out_wire(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom()); @@ -175,7 +175,7 @@ fn test_tail_loop_two_iters() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::new(&hugr).run(&TestContext, []); let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true_or_false()); @@ -238,7 +238,7 @@ fn test_tail_loop_containing_conditional() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::new(&hugr).run(&TestContext, []); let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true()); @@ -290,7 +290,7 @@ fn test_conditional() { 2, [PartialValue::new_variant(0, [])], )); - let results = Machine::default().run(&TestContext, &hugr, [(0.into(), arg_pv)]); + let results = Machine::new(&hugr).run(&TestContext, [(0.into(), arg_pv)]); let cond_r1: Value = results.try_read_wire_concrete(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); @@ -387,11 +387,8 @@ fn test_cfg( xor_and_cfg: Hugr, ) { let root = xor_and_cfg.root(); - let results = Machine::default().run( - &TestContext, - &xor_and_cfg, - [(0.into(), inp0), (1.into(), inp1)], - ); + let results = + Machine::new(&xor_and_cfg).run(&TestContext, [(0.into(), inp0), (1.into(), inp1)]); assert_eq!(results.read_out_wire(Wire::new(root, 0)).unwrap(), out0); assert_eq!(results.read_out_wire(Wire::new(root, 1)).unwrap(), out1); @@ -425,7 +422,7 @@ fn test_call( .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) .unwrap(); - let results = Machine::default().run(&TestContext, &hugr, [(0.into(), inp0), (1.into(), inp1)]); + let results = Machine::new(&hugr).run(&TestContext, [(0.into(), inp0), (1.into(), inp1)]); let [res0, res1] = [0, 1].map(|i| results.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); // The two calls alias so both results will be the same: @@ -447,7 +444,7 @@ fn test_region() { .finish_prelude_hugr_with_outputs(nested.outputs()) .unwrap(); let [nested_input, _] = hugr.get_io(nested.node()).unwrap(); - let whole_hugr_results = Machine::default().run(&TestContext, &hugr, [(0.into(), pv_true())]); + let whole_hugr_results = Machine::new(&hugr).run(&TestContext, [(0.into(), pv_true())]); assert_eq!( whole_hugr_results.read_out_wire(Wire::new(nested_input, 0)), Some(pv_true()) @@ -467,7 +464,7 @@ fn test_region() { let subview = DescendantsGraph::::try_new(&hugr, nested.node()).unwrap(); // Do not provide a value on the second input (constant false in the whole hugr, above) - let sub_hugr_results = Machine::default().run(&TestContext, subview, [(0.into(), pv_true())]); + let sub_hugr_results = Machine::new(subview).run(&TestContext, [(0.into(), pv_true())]); assert_eq!( sub_hugr_results.read_out_wire(Wire::new(nested_input, 0)), Some(pv_true()) @@ -515,7 +512,7 @@ fn test_module() { let hugr = modb.finish_hugr(&EMPTY_REG).unwrap(); let [f2_inp, _] = hugr.get_io(f2.node()).unwrap(); - let results_just_main = Machine::default().run(&TestContext, &hugr, [(0.into(), pv_true())]); + let results_just_main = Machine::new(&hugr).run(&TestContext, [(0.into(), pv_true())]); assert_eq!( results_just_main.read_out_wire(Wire::new(f2_inp, 0)), Some(PartialValue::Bottom) @@ -534,9 +531,9 @@ fn test_module() { } let results_two_calls = { - let mut m = Machine::default(); - m.prepopulate_df_inputs(&hugr, f2.node(), [(0.into(), pv_true())]); - m.run(&TestContext, &hugr, [(0.into(), pv_false())]) + let mut m = Machine::new(&hugr); + m.prepopulate_df_inputs(f2.node(), [(0.into(), pv_true())]); + m.run(&TestContext, [(0.into(), pv_false())]) }; for call in [f2_call, main_call] { From c5bd7b0293f0de7e9c2e7b5f4be9143f53d8ba91 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 26 Nov 2024 10:32:10 +0000 Subject: [PATCH 270/281] Machine::run owns DFContext. Or &mut ?? (TODO update interpret_leaf_op to &mut self) --- hugr-passes/src/dataflow/datalog.rs | 6 +++--- hugr-passes/src/dataflow/test.rs | 26 +++++++++++++------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 74dfc68ac..490a0fb36 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -74,7 +74,7 @@ impl Machine { /// or if any `in_values` are provided for a module-rooted Hugr without a function `"main"`. pub fn run( mut self, - context: &impl DFContext, + context: impl DFContext, in_values: impl IntoIterator)>, ) -> AnalysisResults { let mut in_values = in_values.into_iter(); @@ -127,7 +127,7 @@ impl Machine { } pub(super) fn run_datalog( - ctx: &impl DFContext, + ctx: impl DFContext, hugr: H, in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, ) -> AnalysisResults { @@ -187,7 +187,7 @@ pub(super) fn run_datalog( if !op_t.is_container(), if let Some(sig) = op_t.dataflow_signature(), node_in_value_row(n, vs), - if let Some(outs) = propagate_leaf_op(ctx, &hugr, *n, &vs[..], sig.output_count()), + if let Some(outs) = propagate_leaf_op(&ctx, &hugr, *n, &vs[..], sig.output_count()), for (p, v) in (0..).map(OutgoingPort::from).zip(outs); // DFG -------------------- diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 8f9d18bf9..6df953070 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -60,7 +60,7 @@ fn test_make_tuple() { let v3 = builder.make_tuple([v1, v2]).unwrap(); let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap(); - let results = Machine::new(&hugr).run(&TestContext, []); + let results = Machine::new(&hugr).run(TestContext, []); let x: Value = results.try_read_wire_concrete(v3).unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); @@ -76,7 +76,7 @@ fn test_unpack_tuple_const() { .outputs_arr(); let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap(); - let results = Machine::new(&hugr).run(&TestContext, []); + let results = Machine::new(&hugr).run(TestContext, []); let o1_r: Value = results.try_read_wire_concrete(o1).unwrap(); assert_eq!(o1_r, Value::false_val()); @@ -102,7 +102,7 @@ fn test_tail_loop_never_iterates() { let [tl_o] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::new(&hugr).run(&TestContext, []); + let results = Machine::new(&hugr).run(TestContext, []); let o_r: Value = results.try_read_wire_concrete(tl_o).unwrap(); assert_eq!(o_r, r_v); @@ -137,7 +137,7 @@ fn test_tail_loop_always_iterates() { let [tl_o1, tl_o2] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::new(&hugr).run(&TestContext, []); + let results = Machine::new(&hugr).run(TestContext, []); let o_r1 = results.read_out_wire(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom()); @@ -175,7 +175,7 @@ fn test_tail_loop_two_iters() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let results = Machine::new(&hugr).run(&TestContext, []); + let results = Machine::new(&hugr).run(TestContext, []); let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true_or_false()); @@ -238,7 +238,7 @@ fn test_tail_loop_containing_conditional() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let results = Machine::new(&hugr).run(&TestContext, []); + let results = Machine::new(&hugr).run(TestContext, []); let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true()); @@ -290,7 +290,7 @@ fn test_conditional() { 2, [PartialValue::new_variant(0, [])], )); - let results = Machine::new(&hugr).run(&TestContext, [(0.into(), arg_pv)]); + let results = Machine::new(&hugr).run(TestContext, [(0.into(), arg_pv)]); let cond_r1: Value = results.try_read_wire_concrete(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); @@ -388,7 +388,7 @@ fn test_cfg( ) { let root = xor_and_cfg.root(); let results = - Machine::new(&xor_and_cfg).run(&TestContext, [(0.into(), inp0), (1.into(), inp1)]); + Machine::new(&xor_and_cfg).run(TestContext, [(0.into(), inp0), (1.into(), inp1)]); assert_eq!(results.read_out_wire(Wire::new(root, 0)).unwrap(), out0); assert_eq!(results.read_out_wire(Wire::new(root, 1)).unwrap(), out1); @@ -422,7 +422,7 @@ fn test_call( .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) .unwrap(); - let results = Machine::new(&hugr).run(&TestContext, [(0.into(), inp0), (1.into(), inp1)]); + let results = Machine::new(&hugr).run(TestContext, [(0.into(), inp0), (1.into(), inp1)]); let [res0, res1] = [0, 1].map(|i| results.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); // The two calls alias so both results will be the same: @@ -444,7 +444,7 @@ fn test_region() { .finish_prelude_hugr_with_outputs(nested.outputs()) .unwrap(); let [nested_input, _] = hugr.get_io(nested.node()).unwrap(); - let whole_hugr_results = Machine::new(&hugr).run(&TestContext, [(0.into(), pv_true())]); + let whole_hugr_results = Machine::new(&hugr).run(TestContext, [(0.into(), pv_true())]); assert_eq!( whole_hugr_results.read_out_wire(Wire::new(nested_input, 0)), Some(pv_true()) @@ -464,7 +464,7 @@ fn test_region() { let subview = DescendantsGraph::::try_new(&hugr, nested.node()).unwrap(); // Do not provide a value on the second input (constant false in the whole hugr, above) - let sub_hugr_results = Machine::new(subview).run(&TestContext, [(0.into(), pv_true())]); + let sub_hugr_results = Machine::new(subview).run(TestContext, [(0.into(), pv_true())]); assert_eq!( sub_hugr_results.read_out_wire(Wire::new(nested_input, 0)), Some(pv_true()) @@ -512,7 +512,7 @@ fn test_module() { let hugr = modb.finish_hugr(&EMPTY_REG).unwrap(); let [f2_inp, _] = hugr.get_io(f2.node()).unwrap(); - let results_just_main = Machine::new(&hugr).run(&TestContext, [(0.into(), pv_true())]); + let results_just_main = Machine::new(&hugr).run(TestContext, [(0.into(), pv_true())]); assert_eq!( results_just_main.read_out_wire(Wire::new(f2_inp, 0)), Some(PartialValue::Bottom) @@ -533,7 +533,7 @@ fn test_module() { let results_two_calls = { let mut m = Machine::new(&hugr); m.prepopulate_df_inputs(f2.node(), [(0.into(), pv_true())]); - m.run(&TestContext, [(0.into(), pv_false())]) + m.run(TestContext, [(0.into(), pv_false())]) }; for call in [f2_call, main_call] { From eecdb22c957a2b7d0576644e60aae2e3222cefe8 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 4 Dec 2024 10:30:31 +0000 Subject: [PATCH 271/281] interpret_leaf_op is &mut to allow impls that do caching etc. --- hugr-passes/src/dataflow.rs | 4 ++-- hugr-passes/src/dataflow/datalog.rs | 6 +++--- hugr-passes/src/dataflow/test.rs | 3 +-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 5faf3d733..bb3023c38 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -20,7 +20,7 @@ use hugr_core::{Hugr, Node}; /// must implement this trait (including providing an appropriate domain type `V`). pub trait DFContext: ConstLoader { /// Given lattice values for each input, update lattice values for the (dataflow) outputs. - /// For extension ops only, excluding [MakeTuple] and [UnpackTuple]. + /// For extension ops only, excluding [MakeTuple] and [UnpackTuple] which are handled automatically. /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] /// which is the correct value to leave if nothing can be deduced about that output. /// (The default does nothing, i.e. leaves `Top` for all outputs.) @@ -28,7 +28,7 @@ pub trait DFContext: ConstLoader { /// [MakeTuple]: hugr_core::extension::prelude::MakeTuple /// [UnpackTuple]: hugr_core::extension::prelude::UnpackTuple fn interpret_leaf_op( - &self, + &mut self, _node: Node, _e: &ExtensionOp, _ins: &[PartialValue], diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 490a0fb36..172d87c26 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -127,7 +127,7 @@ impl Machine { } pub(super) fn run_datalog( - ctx: impl DFContext, + mut ctx: impl DFContext, hugr: H, in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, ) -> AnalysisResults { @@ -187,7 +187,7 @@ pub(super) fn run_datalog( if !op_t.is_container(), if let Some(sig) = op_t.dataflow_signature(), node_in_value_row(n, vs), - if let Some(outs) = propagate_leaf_op(&ctx, &hugr, *n, &vs[..], sig.output_count()), + if let Some(outs) = propagate_leaf_op(&mut ctx, &hugr, *n, &vs[..], sig.output_count()), for (p, v) in (0..).map(OutgoingPort::from).zip(outs); // DFG -------------------- @@ -329,7 +329,7 @@ pub(super) fn run_datalog( } fn propagate_leaf_op( - ctx: &impl DFContext, + ctx: &mut impl DFContext, hugr: &impl HugrView, n: Node, ins: &[PV], diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 6df953070..13815d186 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -387,8 +387,7 @@ fn test_cfg( xor_and_cfg: Hugr, ) { let root = xor_and_cfg.root(); - let results = - Machine::new(&xor_and_cfg).run(TestContext, [(0.into(), inp0), (1.into(), inp1)]); + let results = Machine::new(&xor_and_cfg).run(TestContext, [(0.into(), inp0), (1.into(), inp1)]); assert_eq!(results.read_out_wire(Wire::new(root, 0)).unwrap(), out0); assert_eq!(results.read_out_wire(Wire::new(root, 1)).unwrap(), out1); From 79e98f2e3436987e75fa3d84696ae8d0a4bb6130 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 4 Dec 2024 12:44:58 +0000 Subject: [PATCH 272/281] Fix post-merge extension issues by using TEST_REG everywhere --- hugr-passes/src/const_fold/test.rs | 49 ++++++++++++++---------------- 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 7958a1194..35888f1b5 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -4,7 +4,6 @@ use itertools::Itertools; use lazy_static::lazy_static; use rstest::rstest; -use crate::test::TEST_REG; use hugr_core::builder::{ endo_sig, inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer, @@ -13,12 +12,11 @@ use hugr_core::extension::prelude::{ bool_t, const_ok, error_type, string_type, sum_with_error, ConstError, ConstString, MakeTuple, UnpackTuple, }; -use hugr_core::extension::ExtensionRegistry; + use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::{constant::CustomConst, handle::BasicBlockID, OpTag, OpTrait, OpType, Value}; use hugr_core::std_extensions::arithmetic::{ - self, conversions::ConvertOpDef, float_ops::FloatOps, float_types::{float64_type, ConstF64}, @@ -29,8 +27,10 @@ use hugr_core::std_extensions::logic::LogicOp; use hugr_core::types::{Signature, SumType, Type, TypeRow, TypeRowRV}; use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node}; -use super::{constant_fold_pass, ConstFoldContext, ConstFoldPass, ValueHandle}; use crate::dataflow::{partial_from_const, DFContext, PartialValue}; +use crate::test::TEST_REG; + +use super::{constant_fold_pass, ConstFoldContext, ConstFoldPass, ValueHandle}; #[rstest] #[case(ConstInt::new_u(4, 2).unwrap(), true)] @@ -1423,12 +1423,11 @@ fn test_via_part_unknown_tuple() { let res = builder .add_dataflow_op(IntOpDef::iadd.with_log_width(3), [a, c]) .unwrap(); - let reg = ExtensionRegistry::try_new([arithmetic::int_types::EXTENSION.to_owned()]).unwrap(); let mut hugr = builder - .finish_hugr_with_outputs(res.outputs(), ®) + .finish_hugr_with_outputs(res.outputs(), &TEST_REG) .unwrap(); - constant_fold_pass(&mut hugr, ®); + constant_fold_pass(&mut hugr, &TEST_REG); // We expect: root dfg, input, output, const 9, load constant, iadd let mut expected_op_tags: HashSet<_, std::hash::RandomState> = [ @@ -1452,8 +1451,7 @@ fn test_via_part_unknown_tuple() { assert!(expected_op_tags.is_empty()); } -fn tail_loop_hugr(int_cst: ConstInt) -> (Hugr, ExtensionRegistry) { - let reg = ExtensionRegistry::try_new([arithmetic::int_types::EXTENSION.to_owned()]).unwrap(); +fn tail_loop_hugr(int_cst: ConstInt) -> Hugr { let int_ty = int_cst.get_type(); let lw = int_cst.log_width(); let mut builder = DFGBuilder::new(inout_sig(bool_t(), int_ty.clone())).unwrap(); @@ -1471,17 +1469,17 @@ fn tail_loop_hugr(int_cst: ConstInt) -> (Hugr, ExtensionRegistry) { .unwrap(); let hugr = builder - .finish_hugr_with_outputs(add.outputs(), ®) + .finish_hugr_with_outputs(add.outputs(), &TEST_REG) .unwrap(); - (hugr, reg) + hugr } #[test] fn test_tail_loop_unknown() { let cst5 = ConstInt::new_u(3, 5).unwrap(); - let (mut h, reg) = tail_loop_hugr(cst5.clone()); + let mut h = tail_loop_hugr(cst5.clone()); - constant_fold_pass(&mut h, ®); + constant_fold_pass(&mut h, &TEST_REG); // Must keep the loop, even though we know the output, in case the output doesn't happen assert_eq!(h.node_count(), 12); let tl = h @@ -1509,7 +1507,7 @@ fn test_tail_loop_unknown() { .map(tag_string) .sorted() .collect::>(), - Vec::from([ + vec![ "Const", "Const", "Input", @@ -1517,7 +1515,7 @@ fn test_tail_loop_unknown() { "LoadConst", "Output", "TailLoop" - ]) + ] ); assert_eq!( @@ -1563,26 +1561,25 @@ fn test_tail_loop_unknown() { #[test] fn test_tail_loop_never_iterates() { - let (mut h, reg) = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap()); + let mut h = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap()); ConstFoldPass::default() .with_inputs([(0, Value::true_val())]) // true = 1 = break - .run(&mut h, ®) + .run(&mut h, &TEST_REG) .unwrap(); assert_fully_folded(&h, &ConstInt::new_u(4, 12).unwrap().into()); } #[test] fn test_tail_loop_increase_termination() { - let (mut h, reg) = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap()); + let mut h = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap()); ConstFoldPass::default() .allow_increase_termination() - .run(&mut h, ®) + .run(&mut h, &TEST_REG) .unwrap(); assert_fully_folded(&h, &ConstInt::new_u(4, 12).unwrap().into()); } -fn cfg_hugr() -> (Hugr, ExtensionRegistry) { - let reg = ExtensionRegistry::try_new([arithmetic::int_types::EXTENSION.to_owned()]).unwrap(); +fn cfg_hugr() -> Hugr { let int_ty = INT_TYPES[4].clone(); let mut builder = DFGBuilder::new(inout_sig(vec![bool_t(); 2], int_ty.clone())).unwrap(); let [p, q] = builder.input_wires_arr(); @@ -1621,9 +1618,9 @@ fn cfg_hugr() -> (Hugr, ExtensionRegistry) { let cfg = cfg.finish_sub_container().unwrap(); let nested = nested.finish_with_outputs(cfg.outputs()).unwrap(); let hugr = builder - .finish_hugr_with_outputs(nested.outputs(), ®) + .finish_hugr_with_outputs(nested.outputs(), &TEST_REG) .unwrap(); - (hugr, reg) + hugr } #[rstest] @@ -1637,11 +1634,11 @@ fn test_cfg( #[case] fold_blk: bool, #[case] fold_res: Option, ) { - let (backup, reg) = cfg_hugr(); + let backup = cfg_hugr(); let mut hugr = backup.clone(); let pass = ConstFoldPass::default() .with_inputs(inputs.into_iter().map(|(p, b)| (*p, Value::from_bool(*b)))); - pass.run(&mut hugr, ®).unwrap(); + pass.run(&mut hugr, &TEST_REG).unwrap(); // CFG inside DFG retained let nested = hugr .children(hugr.root()) @@ -1701,7 +1698,7 @@ fn test_cfg( let mut hugr2 = backup; pass.allow_increase_termination() - .run(&mut hugr2, ®) + .run(&mut hugr2, &TEST_REG) .unwrap(); assert_fully_folded(&hugr2, &res_v); } else { From 30b860fb3d9cc2665ed48ba179ea4948ec7f6a1c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 4 Dec 2024 12:52:17 +0000 Subject: [PATCH 273/281] clippy --- hugr-passes/src/const_fold.rs | 6 +++--- hugr-passes/src/const_fold/test.rs | 21 ++++++++++----------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 85c275e42..1f139b4aa 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -219,14 +219,14 @@ pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { struct ConstFoldContext<'a, H>(&'a H); -impl<'a, H: HugrView> std::ops::Deref for ConstFoldContext<'a, H> { +impl std::ops::Deref for ConstFoldContext<'_, H> { type Target = H; fn deref(&self) -> &H { self.0 } } -impl<'a, H: HugrView> ConstLoader for ConstFoldContext<'a, H> { +impl ConstLoader for ConstFoldContext<'_, H> { fn value_from_opaque(&self, loc: ConstLocation, val: &OpaqueValue) -> Option { Some(ValueHandle::new_opaque(loc, val.clone())) } @@ -254,7 +254,7 @@ impl<'a, H: HugrView> ConstLoader for ConstFoldContext<'a, H> { } } -impl<'a, H: HugrView> DFContext for ConstFoldContext<'a, H> { +impl DFContext for ConstFoldContext<'_, H> { fn interpret_leaf_op( &mut self, node: Node, diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 35888f1b5..56804944f 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -39,8 +39,8 @@ fn value_handling(#[case] k: impl CustomConst + Clone, #[case] eq: bool) { let n = Node::from(portgraph::NodeIndex::new(7)); let st = SumType::new([vec![k.get_type()], vec![]]); let subject_val = Value::sum(0, [k.clone().into()], st).unwrap(); - let mut temp = Hugr::default(); - let ctx: ConstFoldContext = ConstFoldContext(&mut temp); + let temp = Hugr::default(); + let ctx: ConstFoldContext = ConstFoldContext(&temp); let v1 = partial_from_const(&ctx, n, &subject_val); let v1_subfield = { @@ -111,8 +111,8 @@ fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { v.get_custom_value::().unwrap().value() } let [n, n_a, n_b] = [0, 1, 2].map(portgraph::NodeIndex::new).map(Node::from); - let mut temp = Hugr::default(); - let mut ctx = ConstFoldContext(&mut temp); + let temp = Hugr::default(); + let mut ctx = ConstFoldContext(&temp); let v_a = partial_from_const(&ctx, n_a, &f2c(a)); let v_b = partial_from_const(&ctx, n_b, &f2c(b)); assert_eq!(unwrap_float(v_a.clone()), a); @@ -1468,10 +1468,9 @@ fn tail_loop_hugr(int_cst: ConstInt) -> Hugr { .add_dataflow_op(IntOpDef::iadd.with_log_width(lw), [lcst, loop_out_w]) .unwrap(); - let hugr = builder + builder .finish_hugr_with_outputs(add.outputs(), &TEST_REG) - .unwrap(); - hugr + .unwrap() } #[test] @@ -1617,10 +1616,10 @@ fn cfg_hugr() -> Hugr { cfg.branch(&a, fals, &x).unwrap(); let cfg = cfg.finish_sub_container().unwrap(); let nested = nested.finish_with_outputs(cfg.outputs()).unwrap(); - let hugr = builder + + builder .finish_hugr_with_outputs(nested.outputs(), &TEST_REG) - .unwrap(); - hugr + .unwrap() } #[rstest] @@ -1637,7 +1636,7 @@ fn test_cfg( let backup = cfg_hugr(); let mut hugr = backup.clone(); let pass = ConstFoldPass::default() - .with_inputs(inputs.into_iter().map(|(p, b)| (*p, Value::from_bool(*b)))); + .with_inputs(inputs.iter().map(|(p, b)| (*p, Value::from_bool(*b)))); pass.run(&mut hugr, &TEST_REG).unwrap(); // CFG inside DFG retained let nested = hugr From a4bbedd9ff21fc3f09238e1acee1479ef95a00ad Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 4 Dec 2024 14:07:38 +0000 Subject: [PATCH 274/281] Use 1.75-compatible RandomState --- hugr-passes/src/const_fold/test.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 56804944f..be63dae28 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -1,4 +1,5 @@ use std::collections::HashSet; +use std::collections::hash_map::RandomState; use itertools::Itertools; use lazy_static::lazy_static; @@ -1430,7 +1431,7 @@ fn test_via_part_unknown_tuple() { constant_fold_pass(&mut hugr, &TEST_REG); // We expect: root dfg, input, output, const 9, load constant, iadd - let mut expected_op_tags: HashSet<_, std::hash::RandomState> = [ + let mut expected_op_tags: HashSet<_, RandomState> = [ OpTag::Dfg, OpTag::Input, OpTag::Output, From 07b7d40aa53eafd8c3b758f953f5b38d762525bd Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 4 Dec 2024 14:12:16 +0000 Subject: [PATCH 275/281] Ooops, fmt --- hugr-passes/src/const_fold/test.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index be63dae28..fdb991122 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -1,5 +1,5 @@ -use std::collections::HashSet; use std::collections::hash_map::RandomState; +use std::collections::HashSet; use itertools::Itertools; use lazy_static::lazy_static; From f7fd1c77942381c3840d0d42783c5b7e598546a4 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 4 Dec 2024 14:22:49 +0000 Subject: [PATCH 276/281] ConstFoldPass back to ConstantFoldPass --- hugr-passes/src/const_fold.rs | 6 +++--- hugr-passes/src/const_fold/test.rs | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 1f139b4aa..f1f3aacb1 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -28,13 +28,13 @@ use crate::validation::{ValidatePassError, ValidationLevel}; #[derive(Debug, Clone, Default)] /// A configuration for the Constant Folding pass. -pub struct ConstFoldPass { +pub struct ConstantFoldPass { validation: ValidationLevel, allow_increase_termination: bool, inputs: HashMap, } -impl ConstFoldPass { +impl ConstantFoldPass { /// Sets the validation level used before and after the pass is run pub fn validation_level(mut self, level: ValidationLevel) -> Self { self.validation = level; @@ -214,7 +214,7 @@ fn might_diverge(results: &AnalysisResults, /// Exhaustively apply constant folding to a HUGR. pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { - ConstFoldPass::default().run(h, reg).unwrap() + ConstantFoldPass::default().run(h, reg).unwrap() } struct ConstFoldContext<'a, H>(&'a H); diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index fdb991122..30f6b596c 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -31,7 +31,7 @@ use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node}; use crate::dataflow::{partial_from_const, DFContext, PartialValue}; use crate::test::TEST_REG; -use super::{constant_fold_pass, ConstFoldContext, ConstFoldPass, ValueHandle}; +use super::{constant_fold_pass, ConstFoldContext, ConstantFoldPass, ValueHandle}; #[rstest] #[case(ConstInt::new_u(4, 2).unwrap(), true)] @@ -1562,7 +1562,7 @@ fn test_tail_loop_unknown() { #[test] fn test_tail_loop_never_iterates() { let mut h = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap()); - ConstFoldPass::default() + ConstantFoldPass::default() .with_inputs([(0, Value::true_val())]) // true = 1 = break .run(&mut h, &TEST_REG) .unwrap(); @@ -1572,7 +1572,7 @@ fn test_tail_loop_never_iterates() { #[test] fn test_tail_loop_increase_termination() { let mut h = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap()); - ConstFoldPass::default() + ConstantFoldPass::default() .allow_increase_termination() .run(&mut h, &TEST_REG) .unwrap(); @@ -1636,7 +1636,7 @@ fn test_cfg( ) { let backup = cfg_hugr(); let mut hugr = backup.clone(); - let pass = ConstFoldPass::default() + let pass = ConstantFoldPass::default() .with_inputs(inputs.iter().map(|(p, b)| (*p, Value::from_bool(*b)))); pass.run(&mut hugr, &TEST_REG).unwrap(); // CFG inside DFG retained From 8e121e08f100524c2080545e94b170820a5d2831 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 9 Dec 2024 11:21:13 +0000 Subject: [PATCH 277/281] reinstate ConstFoldError --- hugr-passes/src/const_fold.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index f1f3aacb1..c520013aa 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -4,6 +4,7 @@ pub mod value_handle; use std::collections::{HashMap, HashSet, VecDeque}; +use thiserror::Error; use hugr_core::{ extension::ExtensionRegistry, @@ -34,6 +35,14 @@ pub struct ConstantFoldPass { inputs: HashMap, } +#[derive(Debug, Error)] +/// Errors produced by [ConstantFoldPass]. +pub enum ConstFoldError { + #[error(transparent)] + #[allow(missing_docs)] + Validation(#[from] ValidatePassError), +} + impl ConstantFoldPass { /// Sets the validation level used before and after the pass is run pub fn validation_level(mut self, level: ValidationLevel) -> Self { @@ -63,7 +72,7 @@ impl ConstantFoldPass { } /// Run the Constant Folding pass. - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { + fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { let fresh_node = Node::from(portgraph::NodeIndex::new( hugr.nodes().max().map_or(0, |n| n.index() + 1), )); @@ -132,7 +141,7 @@ impl ConstantFoldPass { &self, hugr: &mut H, reg: &ExtensionRegistry, - ) -> Result<(), ValidatePassError> { + ) -> Result<(), ConstFoldError> { self.validation .run_validated_pass(hugr, reg, |hugr: &mut H, _| self.run_no_validate(hugr)) } From 0d0b0d1d6e3a59027c0de06e0817707eced573a1 Mon Sep 17 00:00:00 2001 From: Douglas Wilson <141026920+doug-q@users.noreply.github.com> Date: Tue, 10 Dec 2024 12:40:38 +0000 Subject: [PATCH 278/281] Update hugr-passes/src/const_fold.rs Co-authored-by: Craig Roy --- hugr-passes/src/const_fold.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index c520013aa..9b97ddbf7 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -280,7 +280,7 @@ impl DFContext for ConstFoldContext<'_, H> { .filter_map(|((i, ty), pv)| { pv.clone() .try_into_concrete(ty) - .map_or(None, |v| Some((IncomingPort::from(i), v))) + .map(|v| (IncomingPort::from(i), v)) }) .collect::>(); for (p, v) in op.constant_fold(&known_ins).unwrap_or_default() { From 0bd18e92ec79cc2584375794cc301e08b6f4e9af Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Tue, 10 Dec 2024 13:19:52 +0000 Subject: [PATCH 279/281] fixup --- hugr-passes/src/const_fold.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 9b97ddbf7..65bff8667 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -280,6 +280,7 @@ impl DFContext for ConstFoldContext<'_, H> { .filter_map(|((i, ty), pv)| { pv.clone() .try_into_concrete(ty) + .ok() .map(|v| (IncomingPort::from(i), v)) }) .collect::>(); From 8698921937dfdd3e1f90ecf9ce0306557ebac455 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Wed, 11 Dec 2024 10:27:15 +0000 Subject: [PATCH 280/281] ConstFoldError::Validation => ConstFoldError::ValidationError --- hugr-passes/src/const_fold.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 842649922..8fe4d93a9 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -39,7 +39,7 @@ pub struct ConstantFoldPass { pub enum ConstFoldError { #[error(transparent)] #[allow(missing_docs)] - Validation(#[from] ValidatePassError), + ValidationError(#[from] ValidatePassError), } impl ConstantFoldPass { From 54e3a8c59d53e6554c1161255a4f8ef9ac805f09 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Wed, 11 Dec 2024 10:33:41 +0000 Subject: [PATCH 281/281] ConstFoldError non_exhaustive --- hugr-passes/src/const_fold.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 8fe4d93a9..82af5e1dc 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -35,6 +35,7 @@ pub struct ConstantFoldPass { } #[derive(Debug, Error)] +#[non_exhaustive] /// Errors produced by [ConstantFoldPass]. pub enum ConstFoldError { #[error(transparent)]