diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 67136329b..a70cb2f32 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -17,6 +17,7 @@ bench = false [dependencies] hugr-core = { path = "../hugr-core", version = "0.13.3" } +portgraph = { workspace = true } ascent = { version = "0.7.0" } itertools = { workspace = true } lazy_static = { workspace = true } diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 7bd181f5f..82af5e1dc 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -1,213 +1,291 @@ -//! Constant folding routines. +#![warn(missing_docs)] +//! Constant-folding pass. +//! An (example) use of the [dataflow analysis framework](super::dataflow). -use std::collections::{BTreeSet, HashMap}; - -use hugr_core::builder::inout_sig; -use itertools::Itertools; +pub mod value_handle; +use std::collections::{HashMap, HashSet, VecDeque}; use thiserror::Error; -use hugr_core::hugr::SimpleReplacementError; -use hugr_core::types::SumType; -use hugr_core::Direction; use hugr_core::{ - builder::{DFGBuilder, Dataflow, DataflowHugr}, - extension::{fold_out_row, ConstFoldResult}, hugr::{ hugrmut::HugrMut, - rewrite::consts::{RemoveConst, RemoveLoadConstant}, - views::SiblingSubgraph, + views::{DescendantsGraph, ExtractHugr, HierarchyView}, + }, + ops::{ + constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant, + Value, }, - ops::{OpType, Value}, - type_row, Hugr, HugrView, IncomingPort, Node, SimpleReplacement, + types::{EdgeKind, TypeArg}, + HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire, }; +use value_handle::ValueHandle; +use crate::dataflow::{ + partial_from_const, AbstractValue, AnalysisResults, ConstLoader, ConstLocation, DFContext, + Machine, PartialValue, TailLoopTermination, +}; use crate::validation::{ValidatePassError, ValidationLevel}; -#[derive(Error, Debug)] -#[allow(missing_docs)] -pub enum ConstFoldError { - #[error(transparent)] - SimpleReplacementError(#[from] SimpleReplacementError), - #[error(transparent)] - ValidationError(#[from] ValidatePassError), -} - -#[derive(Debug, Clone, Copy, Default)] +#[derive(Debug, Clone, Default)] /// A configuration for the Constant Folding pass. pub struct ConstantFoldPass { validation: ValidationLevel, + allow_increase_termination: bool, + inputs: HashMap, } -impl ConstantFoldPass { - /// Create a new `ConstFoldConfig` with default configuration. - pub fn new() -> Self { - Self::default() - } +#[derive(Debug, Error)] +#[non_exhaustive] +/// Errors produced by [ConstantFoldPass]. +pub enum ConstFoldError { + #[error(transparent)] + #[allow(missing_docs)] + ValidationError(#[from] ValidatePassError), +} - /// Build a `ConstFoldConfig` with the given [ValidationLevel]. +impl ConstantFoldPass { + /// Sets the validation level used before and after the pass is run pub fn validation_level(mut self, level: ValidationLevel) -> Self { self.validation = level; self } + /// Allows the pass to remove potentially-non-terminating [TailLoop]s and [CFG] if their + /// result (if/when they do terminate) is either known or not needed. + /// + /// [TailLoop]: hugr_core::ops::TailLoop + /// [CFG]: hugr_core::ops::CFG + pub fn allow_increase_termination(mut self) -> Self { + self.allow_increase_termination = true; + self + } + + /// Specifies any number of external inputs to provide to the Hugr (on root-node + /// in-ports). Each supercedes any previous value on the same in-port. + pub fn with_inputs( + mut self, + inputs: impl IntoIterator, Value)>, + ) -> Self { + self.inputs + .extend(inputs.into_iter().map(|(p, v)| (p.into(), v))); + self + } + /// Run the Constant Folding pass. + fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { + let fresh_node = Node::from(portgraph::NodeIndex::new( + hugr.nodes().max().map_or(0, |n| n.index() + 1), + )); + let inputs = self.inputs.iter().map(|(p, v)| { + ( + *p, + partial_from_const( + &ConstFoldContext(hugr), + ConstLocation::Field(p.index(), &fresh_node.into()), + v, + ), + ) + }); + + let results = Machine::new(&hugr).run(ConstFoldContext(hugr), inputs); + let mut keep_nodes = HashSet::new(); + self.find_needed_nodes(&results, &mut keep_nodes); + let [root_inp, _] = hugr.get_io(hugr.root()).unwrap(); + + let remove_nodes = hugr + .nodes() + .filter(|n| !keep_nodes.contains(n)) + .collect::>(); + let wires_to_break = keep_nodes + .into_iter() + .flat_map(|n| hugr.node_inputs(n).map(move |ip| (n, ip))) + .filter(|(n, ip)| { + *n != hugr.root() + && matches!(hugr.get_optype(*n).port_kind(*ip), Some(EdgeKind::Value(_))) + }) + // Note we COULD filter out (avoid breaking) wires from other nodes that we are keeping. + // This would insert fewer constants, but potentially expose less parallelism. + .filter_map(|(n, ip)| { + let (src, outp) = hugr.single_linked_output(n, ip).unwrap(); + // Avoid breaking edges from existing LoadConstant (we'd only add another) + // or from root input node (any "external inputs" provided will show up here + // - potentially also in other places which this won't catch) + (!hugr.get_optype(src).is_load_constant() && src != root_inp).then_some(( + n, + ip, + results + .try_read_wire_concrete::(Wire::new(src, outp)) + .ok()?, + )) + }) + .collect::>(); + + for (n, inport, v) in wires_to_break { + let parent = hugr.get_parent(n).unwrap(); + let datatype = v.get_type(); + // We could try hash-consing identical Consts, but not ATM + let cst = hugr.add_node_with_parent(parent, Const::new(v)); + let lcst = hugr.add_node_with_parent(parent, LoadConstant { datatype }); + hugr.connect(cst, OutgoingPort::from(0), lcst, IncomingPort::from(0)); + hugr.disconnect(n, inport); + hugr.connect(lcst, OutgoingPort::from(0), n, inport); + } + for n in remove_nodes { + hugr.remove_node(n); + } + Ok(()) + } + + /// Run the pass using this configuration pub fn run(&self, hugr: &mut H) -> Result<(), ConstFoldError> { - self.validation.run_validated_pass(hugr, |hugr: &mut H, _| { - loop { - // We can only safely apply a single replacement. Applying a - // replacement removes nodes and edges which may be referenced by - // further replacements returned by find_consts. Even worse, if we - // attempted to apply those replacements, expecting them to fail if - // the nodes and edges they reference had been deleted, they may - // succeed because new nodes and edges reused the ids. - // - // We could be a lot smarter here, keeping track of `LoadConstant` - // nodes and only looking at their out neighbours. - let Some((replace, removes)) = find_consts(hugr, hugr.nodes()).next() else { - break Ok(()); - }; - hugr.apply_rewrite(replace)?; - for rem in removes { - // We are optimistically applying these [RemoveLoadConstant] and - // [RemoveConst] rewrites without checking whether the nodes - // they attempt to remove have remaining uses. If they do, then - // the rewrite fails and we move on. - if let Ok(const_node) = hugr.apply_rewrite(rem) { - // if the LoadConst was removed, try removing the Const too. - let _ = hugr.apply_rewrite(RemoveConst(const_node)); + self.validation + .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) + } + + fn find_needed_nodes( + &self, + results: &AnalysisResults, + needed: &mut HashSet, + ) { + let mut q = VecDeque::new(); + let h = results.hugr(); + q.push_back(h.root()); + while let Some(n) = q.pop_front() { + if !needed.insert(n) { + continue; + }; + + if h.get_optype(n).is_cfg() { + for bb in h.children(n) { + //if results.bb_reachable(bb).unwrap() { // no, we'd need to patch up predicates + q.push_back(bb); + } + } else if let Some(inout) = h.get_io(n) { + // Dataflow. Find minimal nodes necessary to compute output, including StateOrder edges. + q.extend(inout); // Input also necessary for legality even if unreachable + + if !self.allow_increase_termination { + // Also add on anything that might not terminate (even if results not required - + // if its results are required we'll add it by following dataflow, below.) + for ch in h.children(n) { + if might_diverge(results, ch) { + q.push_back(ch); + } } } } - }) + // Also follow dataflow demand + for (src, op) in h.all_linked_outputs(n) { + let needs_predecessor = match h.get_optype(src).port_kind(op).unwrap() { + EdgeKind::Value(_) => { + h.get_optype(src).is_load_constant() + || results + .try_read_wire_concrete::(Wire::new(src, op)) + .is_err() + } + EdgeKind::StateOrder | EdgeKind::Const(_) | EdgeKind::Function(_) => true, + EdgeKind::ControlFlow => false, // we always include all children of a CFG above + _ => true, // needed as EdgeKind non-exhaustive; not knowing what it is, assume the worst + }; + if needs_predecessor { + q.push_back(src); + } + } + } } } -/// For a given op and consts, attempt to evaluate the op. -pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldResult { - let fold_result = match op { - OpType::Tag(t) => fold_out_row([Value::sum( - t.tag, - consts.iter().map(|(_, konst)| konst.clone()), - SumType::new(t.variants.clone()), - ) - .unwrap()]), - OpType::ExtensionOp(ext_op) => ext_op.constant_fold(consts), - _ => None, - }; - debug_assert!(fold_result.as_ref().map_or(true, |x| x.len() - == op.value_port_count(Direction::Outgoing))); - fold_result +// "Diverge" aka "never-terminate" +// TODO would be more efficient to compute this bottom-up and cache (dynamic programming) +fn might_diverge(results: &AnalysisResults, n: Node) -> bool { + let op = results.hugr().get_optype(n); + if op.is_cfg() { + // TODO if the CFG has no cycles (that are possible given predicates) + // then we could say it definitely terminates (i.e. return false) + true + } else if op.is_tail_loop() + && results.tail_loop_terminates(n).unwrap() != TailLoopTermination::NeverContinues + { + // If we can even figure out the number of iterations is bounded that would allow returning false. + true + } else { + // Node does not introduce non-termination, but still non-terminates if any of its children does + results + .hugr() + .children(n) + .any(|ch| might_diverge(results, ch)) + } } -/// Generate a graph that loads and outputs `consts` in order, validating -/// against `reg`. -fn const_graph(consts: Vec) -> Hugr { - let const_types = consts.iter().map(Value::get_type).collect_vec(); - let mut b = DFGBuilder::new(inout_sig(type_row![], const_types)).unwrap(); +/// Exhaustively apply constant folding to a HUGR. +pub fn constant_fold_pass(h: &mut H) { + ConstantFoldPass::default().run(h).unwrap() +} - let outputs = consts - .into_iter() - .map(|c| b.add_load_const(c)) - .collect_vec(); +struct ConstFoldContext<'a, H>(&'a H); - b.finish_hugr_with_outputs(outputs).unwrap() +impl std::ops::Deref for ConstFoldContext<'_, H> { + type Target = H; + fn deref(&self) -> &H { + self.0 + } } -/// Given some `candidate_nodes` to search for LoadConstant operations in `hugr`, -/// return an iterator of possible constant folding rewrites. -/// -/// The [`SimpleReplacement`] replaces an operation with constants that result from -/// evaluating it, the extension registry `reg` is used to validate the -/// replacement HUGR. The vector of [`RemoveLoadConstant`] refer to the -/// LoadConstant nodes that could be removed - they are not automatically -/// removed as they may be used by other operations. -pub fn find_consts<'a, 'r: 'a>( - hugr: &'a impl HugrView, - candidate_nodes: impl IntoIterator + 'a, -) -> impl Iterator)> + 'a { - // track nodes for operations that have already been considered for folding - let mut used_neighbours = BTreeSet::new(); - - candidate_nodes - .into_iter() - .filter_map(move |n| { - // only look at LoadConstant - hugr.get_optype(n).is_load_constant().then_some(())?; - - let (out_p, _) = hugr.out_value_types(n).exactly_one().ok()?; - let neighbours = hugr - .linked_inputs(n, out_p) - .filter(|(n, _)| used_neighbours.insert(*n)) - .collect_vec(); - if neighbours.is_empty() { - // no uses of LoadConstant that haven't already been considered. - return None; - } - let fold_iter = neighbours - .into_iter() - .filter_map(|(neighbour, _)| fold_op(hugr, neighbour)); - Some(fold_iter) - }) - .flatten() -} +impl ConstLoader for ConstFoldContext<'_, H> { + fn value_from_opaque(&self, loc: ConstLocation, val: &OpaqueValue) -> Option { + Some(ValueHandle::new_opaque(loc, val.clone())) + } -/// Attempt to evaluate and generate rewrites for the operation at `op_node` -fn fold_op( - hugr: &impl HugrView, - op_node: Node, -) -> Option<(SimpleReplacement, Vec)> { - // only support leaf folding for now. - let neighbour_op = hugr.get_optype(op_node); - let (in_consts, removals): (Vec<_>, Vec<_>) = hugr - .node_inputs(op_node) - .filter_map(|in_p| { - let (con_op, load_n) = get_const(hugr, op_node, in_p)?; - Some(((in_p, con_op), RemoveLoadConstant(load_n))) - }) - .unzip(); - // attempt to evaluate op - let (nu_out, consts): (HashMap<_, _>, Vec<_>) = fold_leaf_op(neighbour_op, &in_consts)? - .into_iter() - .enumerate() - .filter_map(|(i, (op_out, konst))| { - // for each used port of the op give the nu_out entry and the - // corresponding Value - hugr.single_linked_input(op_node, op_out) - .map(|np| ((np, i.into()), konst)) - }) - .unzip(); - let replacement = const_graph(consts); - let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr) - .expect("Operation should form valid subgraph."); - - let simple_replace = SimpleReplacement::new( - sibling_graph, - replacement, - // no inputs to replacement - HashMap::new(), - nu_out, - ); - Some((simple_replace, removals)) -} + fn value_from_const_hugr( + &self, + loc: ConstLocation, + h: &hugr_core::Hugr, + ) -> Option { + Some(ValueHandle::new_const_hugr(loc, Box::new(h.clone()))) + } -/// If `op_node` is connected to a LoadConstant at `in_p`, return the constant -/// and the LoadConstant node -fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<(Value, Node)> { - let (load_n, _) = hugr.single_linked_output(op_node, in_p)?; - let load_op = hugr.get_optype(load_n).as_load_constant()?; - let const_node = hugr - .single_linked_output(load_n, load_op.constant_port())? - .0; - let const_op = hugr.get_optype(const_node).as_const()?; - - // TODO avoid const clone here - Some((const_op.as_ref().clone(), load_n)) + fn value_from_function(&self, node: Node, type_args: &[TypeArg]) -> Option { + if !type_args.is_empty() { + // TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709) + return None; + }; + // Returning the function body as a value, here, would be sufficient for inlining IndirectCall + // but not for transforming to a direct Call. + let func = DescendantsGraph::>::try_new(&**self, node).ok()?; + Some(ValueHandle::new_const_hugr( + ConstLocation::Node(node), + Box::new(func.extract_hugr()), + )) + } } -/// Exhaustively apply constant folding to a HUGR. -pub fn constant_fold_pass(h: &mut H) { - ConstantFoldPass::default().run(h).unwrap() +impl DFContext for ConstFoldContext<'_, H> { + fn interpret_leaf_op( + &mut self, + node: Node, + op: &ExtensionOp, + ins: &[PartialValue], + outs: &mut [PartialValue], + ) { + let sig = op.signature(); + let known_ins = sig + .input_types() + .iter() + .enumerate() + .zip(ins.iter()) + .filter_map(|((i, ty), pv)| { + pv.clone() + .try_into_concrete(ty) + .ok() + .map(|v| (IncomingPort::from(i), v)) + }) + .collect::>(); + for (p, v) in op.constant_fold(&known_ins).unwrap_or_default() { + outs[p.index()] = + partial_from_const(self, ConstLocation::Field(p.index(), &node.into()), &v); + } + } } #[cfg(test)] -pub(crate) mod test; +mod test; diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 109154657..62e9cdb9e 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -1,29 +1,73 @@ -use crate::const_fold::constant_fold_pass; -use crate::test::TEST_REG; -use hugr_core::builder::{DFGBuilder, Dataflow, DataflowHugr}; +use std::collections::hash_map::RandomState; +use std::collections::HashSet; + +use itertools::Itertools; +use lazy_static::lazy_static; +use rstest::rstest; + +use hugr_core::builder::{ + endo_sig, inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + SubContainer, +}; use hugr_core::extension::prelude::{ - bool_t, const_ok, error_type, string_type, sum_with_error, ConstError, ConstString, UnpackTuple, + bool_t, const_ok, error_type, string_type, sum_with_error, ConstError, ConstString, MakeTuple, + UnpackTuple, +}; + +use hugr_core::hugr::hugrmut::HugrMut; +use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; +use hugr_core::ops::{constant::CustomConst, handle::BasicBlockID, OpTag, OpTrait, OpType, Value}; +use hugr_core::std_extensions::arithmetic::{ + conversions::ConvertOpDef, + float_ops::FloatOps, + float_types::{float64_type, ConstF64}, + int_ops::IntOpDef, + int_types::{ConstInt, INT_TYPES}, }; -use hugr_core::ops::Value; -use hugr_core::std_extensions::arithmetic::int_ops::IntOpDef; -use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; use hugr_core::std_extensions::logic::LogicOp; -use hugr_core::type_row; -use hugr_core::types::{Signature, Type, TypeRow, TypeRowRV}; +use hugr_core::types::{Signature, SumType, Type, TypeRow, TypeRowRV}; +use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node}; -use rstest::rstest; +use crate::dataflow::{partial_from_const, DFContext, PartialValue}; +use crate::test::TEST_REG; -use lazy_static::lazy_static; +use super::{constant_fold_pass, ConstFoldContext, ConstantFoldPass, ValueHandle}; -use super::*; -use hugr_core::builder::Container; -use hugr_core::ops::OpType; -use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; -use hugr_core::std_extensions::arithmetic::float_ops::FloatOps; -use hugr_core::std_extensions::arithmetic::float_types::{float64_type, ConstF64}; +#[rstest] +#[case(ConstInt::new_u(4, 2).unwrap(), true)] +#[case(ConstF64::new(std::f64::consts::PI), false)] +fn value_handling(#[case] k: impl CustomConst + Clone, #[case] eq: bool) { + let n = Node::from(portgraph::NodeIndex::new(7)); + let st = SumType::new([vec![k.get_type()], vec![]]); + let subject_val = Value::sum(0, [k.clone().into()], st).unwrap(); + let temp = Hugr::default(); + let ctx: ConstFoldContext = ConstFoldContext(&temp); + let v1 = partial_from_const(&ctx, n, &subject_val); + + let v1_subfield = { + let PartialValue::PartialSum(ps1) = v1 else { + panic!() + }; + ps1.0 + .into_iter() + .exactly_one() + .unwrap() + .1 + .into_iter() + .exactly_one() + .unwrap() + }; + + let v2 = partial_from_const(&ctx, n, &k.into()); + if eq { + assert_eq!(v1_subfield, v2); + } else { + assert_ne!(v1_subfield, v2); + } +} /// Check that a hugr just loads and returns a single expected constant. -pub fn assert_fully_folded(h: &Hugr, expected_value: &Value) { +pub fn assert_fully_folded(h: &impl HugrView, expected_value: &Value) { assert_fully_folded_with(h, |v| v == expected_value) } @@ -32,7 +76,7 @@ pub fn assert_fully_folded(h: &Hugr, expected_value: &Value) { /// /// [CustomConst::equals_const] is not required to be implemented. Use this /// function for Values containing such a `CustomConst`. -fn assert_fully_folded_with(h: &Hugr, check_value: impl Fn(&Value) -> bool) { +fn assert_fully_folded_with(h: &impl HugrView, check_value: impl Fn(&Value) -> bool) { let mut node_count = 0; for node in h.children(h.root()) { @@ -63,15 +107,25 @@ fn f2c(f: f64) -> Value { #[case(23.5, 435.5, 459.0)] // c = a + b fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { - let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))]; - let add_op: OpType = FloatOps::fadd.into(); - let outs = fold_leaf_op(&add_op, &consts) - .unwrap() - .into_iter() - .map(|(p, v)| (p, v.get_custom_value::().unwrap().value())) - .collect_vec(); + fn unwrap_float(pv: PartialValue) -> f64 { + let v: Value = pv.try_into_concrete(&float64_type()).unwrap(); + v.get_custom_value::().unwrap().value() + } + let [n, n_a, n_b] = [0, 1, 2].map(portgraph::NodeIndex::new).map(Node::from); + let temp = Hugr::default(); + let mut ctx = ConstFoldContext(&temp); + let v_a = partial_from_const(&ctx, n_a, &f2c(a)); + let v_b = partial_from_const(&ctx, n_b, &f2c(b)); + assert_eq!(unwrap_float(v_a.clone()), a); + assert_eq!(unwrap_float(v_b.clone()), b); + + let mut outs = [PartialValue::Bottom]; + let OpType::ExtensionOp(add_op) = OpType::from(FloatOps::fadd) else { + panic!() + }; + ctx.interpret_leaf_op(n, &add_op, &[v_a, v_b], &mut outs); - assert_eq!(outs.as_slice(), &[(0.into(), c)]); + assert_eq!(unwrap_float(outs[0].clone()), c); } fn noargfn(outputs: impl Into) -> Signature { @@ -1240,3 +1294,296 @@ fn test_fold_int_ops() { let expected = Value::true_val(); assert_fully_folded(&h, &expected); } + +#[test] +fn test_via_part_unknown_tuple() { + // fn(x) -> let (a,_b,c) = (4,x,5) // make tuple, unpack tuple + // in a+b + let mut builder = DFGBuilder::new(endo_sig(INT_TYPES[3].clone())).unwrap(); + let [x] = builder.input_wires_arr(); + let cst4 = builder.add_load_value(ConstInt::new_u(3, 4).unwrap()); + let cst5 = builder.add_load_value(ConstInt::new_u(3, 5).unwrap()); + let tuple_ty = TypeRow::from(vec![INT_TYPES[3].clone(); 3]); + let tup = builder + .add_dataflow_op(MakeTuple::new(tuple_ty.clone()), [cst4, x, cst5]) + .unwrap(); + let untup = builder + .add_dataflow_op(UnpackTuple::new(tuple_ty), tup.outputs()) + .unwrap(); + let [a, _b, c] = untup.outputs_arr(); + let res = builder + .add_dataflow_op(IntOpDef::iadd.with_log_width(3), [a, c]) + .unwrap(); + let mut hugr = builder.finish_hugr_with_outputs(res.outputs()).unwrap(); + + constant_fold_pass(&mut hugr); + + // We expect: root dfg, input, output, const 9, load constant, iadd + let mut expected_op_tags: HashSet<_, RandomState> = [ + OpTag::Dfg, + OpTag::Input, + OpTag::Output, + OpTag::Const, + OpTag::LoadConst, + ] + .map(|t| t.to_string()) + .into_iter() + .collect(); + for n in hugr.nodes() { + let t = hugr.get_optype(n); + let removed = expected_op_tags.remove(&t.tag().to_string()); + assert!(removed); + if let Some(c) = t.as_const() { + assert_eq!(c.value, ConstInt::new_u(3, 9).unwrap().into()) + } + } + assert!(expected_op_tags.is_empty()); +} + +fn tail_loop_hugr(int_cst: ConstInt) -> Hugr { + let int_ty = int_cst.get_type(); + let lw = int_cst.log_width(); + let mut builder = DFGBuilder::new(inout_sig(bool_t(), int_ty.clone())).unwrap(); + let [bool_w] = builder.input_wires_arr(); + let lcst = builder.add_load_value(int_cst); + let tlb = builder + .tail_loop_builder([], [(int_ty, lcst)], type_row![]) + .unwrap(); + let [i] = tlb.input_wires_arr(); + // Loop either always breaks, or always iterates, depending on the boolean input + let [loop_out_w] = tlb.finish_with_outputs(bool_w, [i]).unwrap().outputs_arr(); + // The output of the loop is the constant, if the loop terminates + let add = builder + .add_dataflow_op(IntOpDef::iadd.with_log_width(lw), [lcst, loop_out_w]) + .unwrap(); + + builder.finish_hugr_with_outputs(add.outputs()).unwrap() +} + +#[test] +fn test_tail_loop_unknown() { + let cst5 = ConstInt::new_u(3, 5).unwrap(); + let mut h = tail_loop_hugr(cst5.clone()); + + constant_fold_pass(&mut h); + // Must keep the loop, even though we know the output, in case the output doesn't happen + assert_eq!(h.node_count(), 12); + let tl = h + .nodes() + .filter(|n| h.get_optype(*n).is_tail_loop()) + .exactly_one() + .ok() + .unwrap(); + let mut dfg_nodes = Vec::new(); + let mut loop_nodes = Vec::new(); + for n in h.nodes() { + if let Some(p) = h.get_parent(n) { + if p == h.root() { + dfg_nodes.push(n) + } else { + assert_eq!(p, tl); + loop_nodes.push(n); + } + } + } + let tag_string = |n: &Node| format!("{:?}", h.get_optype(*n).tag()); + assert_eq!( + dfg_nodes + .iter() + .map(tag_string) + .sorted() + .collect::>(), + vec![ + "Const", + "Const", + "Input", + "LoadConst", + "LoadConst", + "Output", + "TailLoop" + ] + ); + + assert_eq!( + loop_nodes.iter().map(tag_string).collect::>(), + Vec::from(["Input", "Output", "Const", "LoadConst"]) + ); + + // In the loop, we have a new constant 5 instead of using the loop input + let [loop_in, loop_out] = h.get_io(tl).unwrap(); + assert!(h.input_neighbours(loop_in).next().is_none()); + let (loop_cst, v) = loop_nodes + .into_iter() + .filter_map(|n| h.get_optype(n).as_const().map(|c| (n, c.value()))) + .exactly_one() + .unwrap(); + assert_eq!(v, &cst5.clone().into()); + let loop_lcst = h.output_neighbours(loop_cst).exactly_one().ok().unwrap(); + assert_eq!(h.get_parent(loop_lcst), Some(tl)); + assert_eq!( + h.all_linked_inputs(loop_lcst).collect::>(), + vec![(loop_out, IncomingPort::from(1))] + ); + + // Outer DFG contains two constants (we know) - a 5, used by the loop, and a 10, output. + let [_, root_out] = h.get_io(h.root()).unwrap(); + let mut cst5 = Some(cst5.into()); + for n in dfg_nodes { + let Some(cst) = h.get_optype(n).as_const() else { + continue; + }; + let lcst = h.output_neighbours(n).exactly_one().ok().unwrap(); + let target = h.output_neighbours(lcst).exactly_one().ok().unwrap(); + if Some(cst.value()) == cst5.as_ref() { + cst5 = None; + assert_eq!(target, tl); + } else { + assert_eq!(cst.value(), &ConstInt::new_u(3, 10).unwrap().into()); + assert_eq!(target, root_out) + } + } + assert!(cst5.is_none()); // Found in loop +} + +#[test] +fn test_tail_loop_never_iterates() { + let mut h = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap()); + ConstantFoldPass::default() + .with_inputs([(0, Value::true_val())]) // true = 1 = break + .run(&mut h) + .unwrap(); + assert_fully_folded(&h, &ConstInt::new_u(4, 12).unwrap().into()); +} + +#[test] +fn test_tail_loop_increase_termination() { + let mut h = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap()); + ConstantFoldPass::default() + .allow_increase_termination() + .run(&mut h) + .unwrap(); + assert_fully_folded(&h, &ConstInt::new_u(4, 12).unwrap().into()); +} + +fn cfg_hugr() -> Hugr { + let int_ty = INT_TYPES[4].clone(); + let mut builder = DFGBuilder::new(inout_sig(vec![bool_t(); 2], int_ty.clone())).unwrap(); + let [p, q] = builder.input_wires_arr(); + let int_cst = builder.add_load_value(ConstInt::new_u(4, 1).unwrap()); + let mut nested = builder + .dfg_builder_endo([(int_ty.clone(), int_cst)]) + .unwrap(); + let [i] = nested.input_wires_arr(); + let mut cfg = nested + .cfg_builder([(int_ty.clone(), i)], int_ty.clone().into()) + .unwrap(); + let mut entry = cfg.simple_entry_builder(int_ty.clone().into(), 2).unwrap(); + let [e_i] = entry.input_wires_arr(); + let e_cst7 = entry.add_load_value(ConstInt::new_u(4, 7).unwrap()); + let e_add = entry + .add_dataflow_op(IntOpDef::iadd.with_log_width(4), [e_cst7, e_i]) + .unwrap(); + let entry = entry.finish_with_outputs(p, e_add.outputs()).unwrap(); + + let mut a = cfg + .simple_block_builder(endo_sig(int_ty.clone()), 2) + .unwrap(); + let [a_i] = a.input_wires_arr(); + let a_cst3 = a.add_load_value(ConstInt::new_u(4, 3).unwrap()); + let a_add = a + .add_dataflow_op(IntOpDef::iadd.with_log_width(4), [a_cst3, a_i]) + .unwrap(); + let a = a.finish_with_outputs(q, a_add.outputs()).unwrap(); + + let x = cfg.exit_block(); + let [tru, fals] = [1, 0]; + cfg.branch(&entry, tru, &a).unwrap(); + cfg.branch(&entry, fals, &x).unwrap(); + cfg.branch(&a, tru, &entry).unwrap(); + cfg.branch(&a, fals, &x).unwrap(); + let cfg = cfg.finish_sub_container().unwrap(); + let nested = nested.finish_with_outputs(cfg.outputs()).unwrap(); + + builder.finish_hugr_with_outputs(nested.outputs()).unwrap() +} + +#[rstest] +#[case(&[(0,false)], true, false, Some(8))] +#[case(&[(0,true), (1,false)], true, true, Some(11))] +#[case(&[(1,false)], true, true, None)] +#[case(&[], false, false, None)] +fn test_cfg( + #[case] inputs: &[(usize, bool)], + #[case] fold_entry: bool, + #[case] fold_blk: bool, + #[case] fold_res: Option, +) { + let backup = cfg_hugr(); + let mut hugr = backup.clone(); + let pass = ConstantFoldPass::default() + .with_inputs(inputs.iter().map(|(p, b)| (*p, Value::from_bool(*b)))); + pass.run(&mut hugr).unwrap(); + // CFG inside DFG retained + let nested = hugr + .children(hugr.root()) + .filter(|n| hugr.get_optype(*n).is_dfg()) + .exactly_one() + .ok() + .unwrap(); + let cfg = hugr + .nodes() + .filter(|n| hugr.get_optype(*n).is_cfg()) + .exactly_one() + .ok() + .unwrap(); + assert_eq!(hugr.get_parent(cfg), Some(nested)); + let [entry, exit, a] = hugr.children(cfg).collect::>().try_into().unwrap(); + assert!(hugr.get_optype(exit).is_exit_block()); + for (blk, is_folded, folded_cst, unfolded_cst) in + [(entry, fold_entry, 8, 7), (a, fold_blk, 11, 3)] + { + if is_folded { + assert_fully_folded( + &DescendantsGraph::::try_new(&hugr, blk).unwrap(), + &ConstInt::new_u(4, folded_cst).unwrap().into(), + ); + } else { + let mut expected_tags = + HashSet::from(["Input", "Output", "Leaf", "Const", "LoadConst"]); + for ch in hugr.children(blk) { + let tag = format!("{:?}", hugr.get_optype(ch).tag()); + assert!(expected_tags.remove(tag.as_str()), "Not found: {}", tag); + if let Some(cst) = hugr.get_optype(ch).as_const() { + assert_eq!( + cst.value(), + &ConstInt::new_u(4, unfolded_cst).unwrap().into() + ); + } else if let Some(op) = hugr.get_optype(ch).as_extension_op() { + assert_eq!(op.def().name(), "iadd"); + } + } + } + } + let output_src = hugr + .input_neighbours(hugr.get_io(hugr.root()).unwrap()[1]) + .exactly_one() + .ok() + .unwrap(); + if let Some(res_int) = fold_res { + let res_v = ConstInt::new_u(4, res_int as _).unwrap().into(); + assert!(hugr.get_optype(output_src).is_load_constant()); + let output_cst = hugr + .input_neighbours(output_src) + .exactly_one() + .ok() + .unwrap(); + let cst = hugr.get_optype(output_cst).as_const().unwrap(); + assert_eq!(cst.value(), &res_v); + + let mut hugr2 = backup; + pass.allow_increase_termination().run(&mut hugr2).unwrap(); + assert_fully_folded(&hugr2, &res_v); + } else { + assert_eq!(output_src, nested); + } +} diff --git a/hugr-passes/src/const_fold/value_handle.rs b/hugr-passes/src/const_fold/value_handle.rs new file mode 100644 index 000000000..08e7ed0f0 --- /dev/null +++ b/hugr-passes/src/const_fold/value_handle.rs @@ -0,0 +1,242 @@ +//! Total equality (and hence [AbstractValue] support for [Value]s +//! (by adding a source-Node and part unhashable constants) +use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76. +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use hugr_core::ops::constant::OpaqueValue; +use hugr_core::ops::Value; +use hugr_core::{Hugr, Node}; +use itertools::Either; + +use crate::dataflow::{AbstractValue, ConstLocation}; + +/// A custom constant that has been successfully hashed via [TryHash](hugr_core::ops::constant::TryHash) +#[derive(Clone, Debug)] +pub struct HashedConst { + hash: u64, + pub(super) val: Arc, +} + +impl HashedConst { + pub(super) fn try_new(val: Arc) -> Option { + let mut hasher = DefaultHasher::new(); + val.value().try_hash(&mut hasher).then(|| HashedConst { + hash: hasher.finish(), + val, + }) + } +} + +impl PartialEq for HashedConst { + fn eq(&self, other: &Self) -> bool { + self.hash == other.hash && self.val.value().equal_consts(other.val.value()) + } +} + +impl Eq for HashedConst {} + +impl Hash for HashedConst { + fn hash(&self, state: &mut H) { + state.write_u64(self.hash); + } +} + +/// An [Eq]-able and [Hash]-able leaf (non-[Sum](Value::Sum)) Value +#[derive(Clone, Debug)] +pub enum ValueHandle { + /// A [Value::Extension] that has been hashed + Hashable(HashedConst), + /// Either a [Value::Extension] that can't be hashed, or a [Value::Function]. + Unhashable { + /// The node (i.e. a [Const](hugr_core::ops::Const)) containing the constant + node: Node, + /// Indices within [Value::Sum]s containing the unhashable [Self::Unhashable::leaf] + fields: Vec, + /// The unhashable [Value::Extension] or [Value::Function] + leaf: Either, Arc>, + }, +} + +fn node_and_fields(loc: &ConstLocation) -> (Node, Vec) { + match loc { + ConstLocation::Node(n) => (*n, vec![]), + ConstLocation::Field(idx, elem) => { + let (n, mut f) = node_and_fields(elem); + f.push(*idx); + (n, f) + } + } +} + +impl ValueHandle { + /// Makes a new instance from an [OpaqueValue] given the node and (for a [Sum](Value::Sum)) + /// field indices within that (used only if the custom constant is not hashable). + pub fn new_opaque<'a>(loc: impl Into>, val: OpaqueValue) -> Self { + let arc = Arc::new(val); + let (node, fields) = node_and_fields(&loc.into()); + HashedConst::try_new(arc.clone()).map_or( + Self::Unhashable { + node, + fields, + leaf: Either::Left(arc), + }, + Self::Hashable, + ) + } + + /// New instance for a [Value::Function] found within a node + pub fn new_const_hugr<'a>(loc: impl Into>, val: Box) -> Self { + let (node, fields) = node_and_fields(&loc.into()); + Self::Unhashable { + node, + fields, + leaf: Either::Right(Arc::from(val)), + } + } +} + +impl AbstractValue for ValueHandle {} + +impl PartialEq for ValueHandle { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Hashable(h1), Self::Hashable(h2)) => h1 == h2, + ( + Self::Unhashable { + node: n1, + fields: f1, + leaf: _, + }, + Self::Unhashable { + node: n2, + fields: f2, + leaf: _, + }, + ) => { + // If the keys are equal, we return true since the values must have the + // same provenance, and so be equal. If the keys are different but the + // values are equal, we could return true if we didn't impl Eq, but + // since we do impl Eq, the Hash contract prohibits us from having equal + // values with different hashes. + n1 == n2 && f1 == f2 + } + _ => false, + } + } +} + +impl Eq for ValueHandle {} + +impl Hash for ValueHandle { + fn hash(&self, state: &mut I) { + match self { + ValueHandle::Hashable(hc) => hc.hash(state), + ValueHandle::Unhashable { + node, + fields, + leaf: _, + } => { + node.hash(state); + fields.hash(state); + } + } + } +} + +// Unfortunately we need From for Value to be able to pass +// Value's into interpret_leaf_op. So that probably doesn't make sense... +impl From for Value { + fn from(value: ValueHandle) -> Self { + match value { + ValueHandle::Hashable(HashedConst { val, .. }) + | ValueHandle::Unhashable { + leaf: Either::Left(val), + .. + } => Value::Extension { + e: Arc::try_unwrap(val).unwrap_or_else(|a| a.as_ref().clone()), + }, + ValueHandle::Unhashable { + leaf: Either::Right(hugr), + .. + } => Value::function(Arc::try_unwrap(hugr).unwrap_or_else(|a| a.as_ref().clone())) + .map_err(|e| e.to_string()) + .unwrap(), + } + } +} + +#[cfg(test)] +mod test { + use hugr_core::{ + builder::{endo_sig, DFGBuilder, Dataflow, DataflowHugr}, + extension::prelude::{usize_t, ConstString}, + std_extensions::{ + arithmetic::{ + float_types::{float64_type, ConstF64}, + int_types::{ConstInt, INT_TYPES}, + }, + collections::ListValue, + }, + }; + + use super::*; + + #[test] + fn value_key_eq() { + let n = Node::from(portgraph::NodeIndex::new(0)); + let n2: Node = portgraph::NodeIndex::new(1).into(); + let h1 = ValueHandle::new_opaque(n, ConstString::new("foo".to_string()).into()); + let h2 = ValueHandle::new_opaque(n2, ConstString::new("foo".to_string()).into()); + let h3 = ValueHandle::new_opaque(n, ConstString::new("bar".to_string()).into()); + + assert_eq!(h1, h2); // Node ignored as constant is hashable + assert_ne!(h1, h3); + + // Hashable vs Unhashable is not equal (even with same key): + let f = ConstF64::new(std::f64::consts::PI); + let h4 = ValueHandle::new_opaque(n, f.clone().into()); + assert_ne!(h4, h1); + assert_ne!(h1, h4); + + // Unhashable vals are compared only by key, not content + let f2 = ConstF64::new(std::f64::consts::E); + assert_eq!(h4, ValueHandle::new_opaque(n, f2.clone().into())); + assert_ne!( + h4, + ValueHandle::new_opaque(ConstLocation::Field(5, &n.into()), f2.into()) + ); + + let h = Box::new(make_hugr(1)); + let h5 = ValueHandle::new_const_hugr(n, h.clone()); + assert_eq!(h5, ValueHandle::new_const_hugr(n, Box::new(make_hugr(2)))); + assert_ne!(h5, ValueHandle::new_const_hugr(n2, h)); + } + + fn make_hugr(num_wires: usize) -> Hugr { + let d = DFGBuilder::new(endo_sig(vec![usize_t(); num_wires])).unwrap(); + let inputs = d.input_wires(); + d.finish_hugr_with_outputs(inputs).unwrap() + } + + #[test] + fn value_key_list() { + let v1 = ConstInt::new_u(3, 3).unwrap(); + let v2 = ConstInt::new_u(4, 3).unwrap(); + let v3 = ConstF64::new(std::f64::consts::PI); + + let n = Node::from(portgraph::NodeIndex::new(0)); + + let lst = ListValue::new(INT_TYPES[0].clone(), [v1.into(), v2.into()]); + assert_eq!( + ValueHandle::new_opaque(n, lst.clone().into()), + ValueHandle::new_opaque(ConstLocation::Field(1, &n.into()), lst.into()) + ); + + let lst = ListValue::new(float64_type(), [v3.into()]); + assert_ne!( + ValueHandle::new_opaque(n, lst.clone().into()), + ValueHandle::new_opaque(ConstLocation::Field(3, &n.into()), lst.into()) + ); + } +} diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index bb3023c38..7cfa73835 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -87,7 +87,7 @@ pub trait ConstLoader { /// to their leaves ([Value::Extension] and [Value::Function]), /// converts these using [ConstLoader::value_from_opaque] and [ConstLoader::value_from_const_hugr], /// and builds nested [PartialValue::new_variant] to represent the structure. -fn partial_from_const<'a, V>( +pub fn partial_from_const<'a, V>( cl: &impl ConstLoader, loc: impl Into>, cst: &Value, diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index 0f4704b42..cf0f3d5a4 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -15,6 +15,11 @@ pub struct AnalysisResults { } impl AnalysisResults { + /// Allows reading the Hugr(View) for which the results were computed + pub fn hugr(&self) -> &H { + &self.hugr + } + /// Gets the lattice value computed for the given wire pub fn read_out_wire(&self, w: Wire) -> Option> { self.out_wire_values.get(&w).cloned()