From 7173a7f59a6b23d44a95dce425ef9142763d9514 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Wed, 5 Jun 2024 09:36:25 +0100 Subject: [PATCH 01/12] wip --- hugr-core/src/ops/module.rs | 2 +- hugr-core/src/types.rs | 6 +- hugr-core/src/types/custom.rs | 2 +- hugr-core/src/types/signature.rs | 2 +- hugr-core/src/types/type_param.rs | 10 +- hugr-core/src/types/type_row.rs | 2 +- hugr-passes/Cargo.toml | 1 + hugr-passes/src/const_fold2.rs | 1 + hugr-passes/src/const_fold2/datalog.rs | 237 +++++++++++++++++++++++++ hugr-passes/src/lib.rs | 1 + 10 files changed, 252 insertions(+), 12 deletions(-) create mode 100644 hugr-passes/src/const_fold2.rs create mode 100644 hugr-passes/src/const_fold2/datalog.rs diff --git a/hugr-core/src/ops/module.rs b/hugr-core/src/ops/module.rs index 40c189b68..fef5105fb 100644 --- a/hugr-core/src/ops/module.rs +++ b/hugr-core/src/ops/module.rs @@ -128,7 +128,7 @@ impl OpTrait for AliasDefn { } /// A type alias declaration. Resolved at link time. -#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] #[cfg_attr(test, derive(proptest_derive::Arbitrary))] pub struct AliasDecl { /// Alias name diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index cb8cd0706..dd1cd9b29 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -118,7 +118,7 @@ pub(crate) fn least_upper_bound(mut tags: impl Iterator) -> Ty .into_inner() } -#[derive(Clone, PartialEq, Debug, Eq, Serialize, Deserialize)] +#[derive(Clone, PartialEq, Debug, Eq, Hash, Serialize, Deserialize)] #[serde(tag = "s")] #[non_exhaustive] /// Representation of a Sum type. @@ -196,7 +196,7 @@ impl From for Type { } } -#[derive(Clone, PartialEq, Debug, Eq, derive_more::Display)] +#[derive(Clone, PartialEq, Debug, Eq, Hash, derive_more::Display)] #[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))] /// Core types pub enum TypeEnum { @@ -249,7 +249,7 @@ impl TypeEnum { } #[derive( - Clone, PartialEq, Debug, Eq, derive_more::Display, serde::Serialize, serde::Deserialize, + Clone, PartialEq, Debug, Eq, Hash, derive_more::Display, serde::Serialize, serde::Deserialize, )] #[display(fmt = "{}", "_0")] #[serde(into = "serialize::SerSimpleType", from = "serialize::SerSimpleType")] diff --git a/hugr-core/src/types/custom.rs b/hugr-core/src/types/custom.rs index 93d2e506a..c3c4c3f7e 100644 --- a/hugr-core/src/types/custom.rs +++ b/hugr-core/src/types/custom.rs @@ -12,7 +12,7 @@ use super::{ use super::{Type, TypeName}; /// An opaque type element. Contains the unique identifier of its definition. -#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)] +#[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] pub struct CustomType { extension: ExtensionId, /// Unique identifier of the opaque type. diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index 1cf731eda..542b3fc0b 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -14,7 +14,7 @@ use crate::{Direction, IncomingPort, OutgoingPort, Port}; #[cfg(test)] use {crate::proptest::RecursionDepth, ::proptest::prelude::*, proptest_derive::Arbitrary}; -#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] #[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))] /// Describes the edges required to/from a node, and thus, also the type of a [Graph]. /// This includes both the concept of "signature" in the spec, diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index e7019fe7c..572ac1220 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -19,7 +19,7 @@ use super::{check_typevar_decl, CustomType, Substitution, Type, TypeBound}; /// The upper non-inclusive bound of a [`TypeParam::BoundedNat`] // A None inner value implies the maximum bound: u64::MAX + 1 (all u64 values valid) #[derive( - Clone, Debug, PartialEq, Eq, derive_more::Display, serde::Deserialize, serde::Serialize, + Clone, Debug, PartialEq, Eq, Hash, derive_more::Display, serde::Deserialize, serde::Serialize, )] #[display(fmt = "{}", "_0.map(|i|i.to_string()).unwrap_or(\"-\".to_string())")] #[cfg_attr(test, derive(Arbitrary))] @@ -52,7 +52,7 @@ impl UpperBound { /// [PolyFuncType]: super::PolyFuncType /// [OpDef]: crate::extension::OpDef #[derive( - Clone, Debug, PartialEq, Eq, derive_more::Display, serde::Deserialize, serde::Serialize, + Clone, Debug, PartialEq, Eq, Hash, derive_more::Display, serde::Deserialize, serde::Serialize, )] #[non_exhaustive] #[serde(tag = "tp")] @@ -142,7 +142,7 @@ impl From for TypeParam { } /// A statically-known argument value to an operation. -#[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)] #[non_exhaustive] #[serde(tag = "tya")] pub enum TypeArg { @@ -214,7 +214,7 @@ impl From for TypeArg { } /// Variable in a TypeArg, that is not a [TypeArg::Type] or [TypeArg::Extensions], -#[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)] pub struct TypeArgVariable { idx: usize, cached_decl: TypeParam, @@ -352,7 +352,7 @@ impl TypeArgVariable { /// A serialized representation of a value of a [CustomType] /// restricted to equatable types. -#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] pub struct CustomTypeArg { /// The type of the constant. /// (Exact matches only - the constant is exactly this type.) diff --git a/hugr-core/src/types/type_row.rs b/hugr-core/src/types/type_row.rs index 5be3ddc2f..17182f788 100644 --- a/hugr-core/src/types/type_row.rs +++ b/hugr-core/src/types/type_row.rs @@ -16,7 +16,7 @@ use delegate::delegate; use itertools::Itertools; /// List of types, used for function signatures. -#[derive(Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)] +#[derive(Clone, PartialEq, Eq, Debug, Hash, serde::Serialize, serde::Deserialize)] #[non_exhaustive] #[serde(transparent)] pub struct TypeRow { diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 240468f4c..4dc128a78 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -18,6 +18,7 @@ itertools = { workspace = true } lazy_static = { workspace = true } paste = { workspace = true } thiserror = { workspace = true } +ascent = "0.6.0" [features] extension_inference = ["hugr-core/extension_inference"] diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs new file mode 100644 index 000000000..f01a78c6e --- /dev/null +++ b/hugr-passes/src/const_fold2.rs @@ -0,0 +1 @@ +mod datalog; diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs new file mode 100644 index 000000000..b2e717108 --- /dev/null +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -0,0 +1,237 @@ +use std::hash::{Hash, Hasher}; + +use ascent::{ascent_run, Lattice}; +use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; +use hugr_core::ops::{OpTag, OpTrait, Value}; +use hugr_core::types::{SumType, Type, TypeRow}; +use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; +use itertools::{zip_eq, Itertools}; +use std::collections::HashMap; + +#[derive(PartialEq, Clone, Eq)] +struct HashableHashMap(HashMap); + +impl Hash for HashableHashMap { + fn hash(&self, state: &mut H) { + self.0.keys().for_each(|k| k.hash(state)); + self.0.values().for_each(|v| v.hash(state)); + } +} + +#[derive(PartialEq, Clone, Eq, Hash)] +enum PartialValue { + Bottom(Type), + Value(Node, Type), + PartialSum(HashableHashMap>, SumType), + Top(Type), +} + +impl PartialValue { + fn get_type(&self) -> Type { + match self { + PartialValue::Bottom(t) => t.clone(), + PartialValue::Value(_, t) => t.clone(), + PartialValue::PartialSum(_, t) => t.clone().into(), + PartialValue::Top(t) => t.clone(), + } + } + + fn top_from_hugr(hugr: &impl HugrView, node: Node, port: OutgoingPort) -> Self { + Self::Top( + hugr.signature(node) + .unwrap() + .out_port_type(port) + .unwrap() + .clone(), + ) + } + + fn from_load_constant(hugr: &impl HugrView, node: Node) -> Self { + let load_op = hugr.get_optype(node).as_load_constant().unwrap(); + let const_node = hugr + .single_linked_output(node, load_op.constant_port()) + .unwrap() + .0; + let const_op = hugr.get_optype(const_node).as_const().unwrap(); + Self::Value(const_node, const_op.get_type()) + } + + fn tuple_from_value_row(r: &ValueRow) -> Self { + unimplemented!() + } + +} + +impl PartialOrd for PartialValue { + fn partial_cmp(&self, other: &Self) -> Option { + // TODO we can do better + (self == other).then_some(std::cmp::Ordering::Equal) + } +} + +impl Lattice for PartialValue { + fn meet(self, _other: Self) -> Self { + // should not be required + todo!() + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn join_mut(&mut self, other: Self) -> bool { + debug_assert_eq!(self.get_type(), other.get_type()); + match (self, other) { + (Self::Bottom(_), _) => false, + (s, rhs @ Self::Bottom(_)) => { + *s = rhs; + true + } + (_, Self::Top(_)) => false, + (s @ Self::Top(_), x) => { + *s = x; + true + } + (Self::Value(n1, t), Self::Value(n2, _)) if n1 == &n2 => false, + ( + Self::PartialSum(HashableHashMap(hm1), t), + Self::PartialSum(HashableHashMap(hm2), _), + ) => { + let mut changed = false; + for (k, v) in hm2 { + let row = hm1.entry(k).or_insert_with(|| { + changed = true; + t.get_variant(k) + .unwrap() + .iter() + .cloned() + .map(Self::Top) + .collect_vec() + }); + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.join_mut(rhs); + } + } + changed + } + (s, _) => { + *s = Self::Bottom(s.get_type()); + true + } + } + } +} + +// fn input_row<'a>(inp: impl Iterator) -> impl Iterator { +// todo!() +// } + +#[derive(PartialEq, Clone, Eq, Hash, PartialOrd)] +enum ValueRow { + Values(Vec), + Bottom, +} + +impl ValueRow { + fn into_partial_value(self) -> PartialValue { + todo!() + } + + fn new(tr: &TypeRow) -> Self { + Self::Values(tr.iter().cloned().map(PartialValue::Top).collect_vec()) + } + + fn singleton(tr: &TypeRow, idx: usize, v: PartialValue) -> Self { + let mut r = Self::new(tr); + if let Self::Values(vec) = &mut r { + vec[idx] = v; + } + r + } + + fn iter(&self) -> impl Iterator { + std::iter::empty() + } +} + +impl Lattice for ValueRow { + fn meet(self, other: Self) -> Self { + todo!() + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn join_mut(&mut self, other: Self) -> bool { + match (self, other) { + (Self::Bottom, _) => false, + (s, o @ Self::Bottom) => { + *s = o; + true + } + (s, Self::Values(vs2)) => { + let (b, r) = if let Self::Values(vs1) = s { + if vs1.len() == vs2.len() { + let mut changed = false; + for (v1, v2) in zip_eq(vs1.iter_mut(), vs2.into_iter()) { + changed |= v1.join_mut(v2); + } + (false, changed) + } else { + (true, true) + } + } else { + panic!("impossible") + }; + if b { + *s = Self::Bottom; + } + r + } + } + } +} + +fn node_in_value_row<'a>( + ins: impl Iterator, +) -> impl Iterator { + std::iter::empty() +} + +fn tc(hugr: &impl HugrView, node: Node) { + assert!(OpTag::DataflowParent.is_superset(hugr.get_optype(node).tag())); + let d = DescendantsGraph::<'_, Node>::try_new(hugr, node).unwrap(); + ascent_run! { + relation node(Node) = d.nodes().map(|x| (x,)).collect_vec(); + + relation in_wire(Node, IncomingPort); + in_wire(n,p) <-- node(n), for p in d.node_inputs(*n); + + relation out_wire(Node, OutgoingPort); + out_wire(n,p) <-- node(n), for p in d.node_outputs(*n); + + lattice node_in_value_row(Node, ValueRow); + node_in_value_row(n, ValueRow::new(&hugr.signature(*n).unwrap().input)) <-- node(n); + + lattice out_wire_value(Node, OutgoingPort, PartialValue); + out_wire_value(n,p, PartialValue::top_from_hugr(hugr,*n,*p)) <-- out_wire(n,p); + + node_in_value_row(n,ValueRow::singleton(&hugr.signature(*n).unwrap().input, ip.index(), v.clone())) <-- in_wire(n, ip), + if let Some((m,op)) = hugr.single_linked_output(*n, *ip), out_wire_value(m, op, ?v); + + lattice in_wire_value(Node, IncomingPort, PartialValue); + in_wire_value(n,p,v) <-- node_in_value_row(n, ?vr), for (p,v) in vr.iter(); + + relation load_constant_node(Node); + load_constant_node(n) <-- node(n), if hugr.get_optype(*n).is_load_constant(); + out_wire_value(n, 0.into(), PartialValue::from_load_constant(hugr, *n)) <-- load_constant_node(n); + + relation make_tuple_node(Node); + make_tuple_node(n) <-- node(n), if hugr.get_optype(*n).is_make_tuple(); + + out_wire_value(n,0.into(), PartialValue::tuple_from_value_row(vs)) <-- make_tuple_node(n), node_in_value_row(n, ?vs); + }; +} diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 803196144..eab13b21e 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,6 +1,7 @@ //! Compilation passes acting on the HUGR program representation. pub mod const_fold; +pub mod const_fold2; mod half_node; pub mod merge_bbs; pub mod nest_cfgs; From a00c46a7e8434f12b972d0fc67e32581e114da8b Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Thu, 6 Jun 2024 09:20:35 +0100 Subject: [PATCH 02/12] wip --- hugr-passes/Cargo.toml | 1 + hugr-passes/src/const_fold2/datalog.rs | 173 +++++++++++++++---------- 2 files changed, 109 insertions(+), 65 deletions(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 4dc128a78..4ad98e7ab 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -19,6 +19,7 @@ lazy_static = { workspace = true } paste = { workspace = true } thiserror = { workspace = true } ascent = "0.6.0" +either = "*" [features] extension_inference = ["hugr-core/extension_inference"] diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index b2e717108..f11646ce5 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -18,48 +18,60 @@ impl Hash for HashableHashMap { } } -#[derive(PartialEq, Clone, Eq, Hash)] -enum PartialValue { - Bottom(Type), - Value(Node, Type), - PartialSum(HashableHashMap>, SumType), - Top(Type), -} +struct ValueCache(HashMap); -impl PartialValue { - fn get_type(&self) -> Type { - match self { - PartialValue::Bottom(t) => t.clone(), - PartialValue::Value(_, t) => t.clone(), - PartialValue::PartialSum(_, t) => t.clone().into(), - PartialValue::Top(t) => t.clone(), - } +impl ValueCache { + fn new() -> Self { + Self(HashMap::new()) } - fn top_from_hugr(hugr: &impl HugrView, node: Node, port: OutgoingPort) -> Self { - Self::Top( - hugr.signature(node) - .unwrap() - .out_port_type(port) - .unwrap() - .clone(), - ) + fn get(&mut self, node: Node, value: &Value) -> ValueHandle { + self.0.entry(node).or_insert_with(|| value.clone()); + ValueHandle(node) } +} + +#[derive(PartialEq,Eq,Clone,Hash)] +struct ValueHandle(Node); + +impl ValueHandle { + fn new(node: Node) -> Self { + Self(node) + } +} - fn from_load_constant(hugr: &impl HugrView, node: Node) -> Self { +#[derive(PartialEq, Clone, Eq, Hash)] +enum PartialValue { + Bottom, + Value(ValueHandle), + PartialSum(HashableHashMap>), + Top, +} + +impl PartialValue { + const BOTTOM: Self = Self::Bottom; + const BOTTOM_REF: &'static Self = &Self::BOTTOM; + fn from_load_constant(cache: &mut ValueCache, hugr: &impl HugrView, node: Node) -> Self { let load_op = hugr.get_optype(node).as_load_constant().unwrap(); let const_node = hugr .single_linked_output(node, load_op.constant_port()) .unwrap() .0; let const_op = hugr.get_optype(const_node).as_const().unwrap(); - Self::Value(const_node, const_op.get_type()) + Self::Value(cache.get(const_node, const_op.value())) } fn tuple_from_value_row(r: &ValueRow) -> Self { - unimplemented!() + if !r.initialised() { + return Self::Top + } + match r { + ValueRow::Bottom => Self::Bottom, + ValueRow::Values(vs) => { + PartialValue::PartialSum(HashableHashMap([(0usize, vs.clone())].into_iter().collect())) + } + } } - } impl PartialOrd for PartialValue { @@ -81,42 +93,37 @@ impl Lattice for PartialValue { } fn join_mut(&mut self, other: Self) -> bool { - debug_assert_eq!(self.get_type(), other.get_type()); match (self, other) { - (Self::Bottom(_), _) => false, - (s, rhs @ Self::Bottom(_)) => { - *s = rhs; + (Self::Bottom, _) => false, + (s, Self::Bottom) => { + *s = Self::Bottom; true } - (_, Self::Top(_)) => false, - (s @ Self::Top(_), x) => { - *s = x; + (_, Self::Top) => false, + (s @ Self::Top, x) => { + *s = Self::Top; true } - (Self::Value(n1, t), Self::Value(n2, _)) if n1 == &n2 => false, + (Self::Value(h1), Self::Value(h2)) if h1 == &h2 => false, ( - Self::PartialSum(HashableHashMap(hm1), t), - Self::PartialSum(HashableHashMap(hm2), _), + Self::PartialSum(HashableHashMap(hm1)), + Self::PartialSum(HashableHashMap(hm2)) ) => { let mut changed = false; for (k, v) in hm2 { - let row = hm1.entry(k).or_insert_with(|| { + if let Some(row) = hm1.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.join_mut(rhs); + } + } else { + hm1.insert(k, v); changed = true; - t.get_variant(k) - .unwrap() - .iter() - .cloned() - .map(Self::Top) - .collect_vec() - }); - for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { - changed |= lhs.join_mut(rhs); } } changed } (s, _) => { - *s = Self::Bottom(s.get_type()); + *s = Self::Bottom; true } } @@ -133,25 +140,48 @@ enum ValueRow { Bottom, } + impl ValueRow { fn into_partial_value(self) -> PartialValue { todo!() } - fn new(tr: &TypeRow) -> Self { - Self::Values(tr.iter().cloned().map(PartialValue::Top).collect_vec()) + fn new(len: usize) -> Self { + Self::Values(vec![PartialValue::Top; len]) } - fn singleton(tr: &TypeRow, idx: usize, v: PartialValue) -> Self { - let mut r = Self::new(tr); + fn singleton(len: usize, idx: usize, v: PartialValue) -> Self { + assert!(idx < len); + let mut r = Self::new(len); if let Self::Values(vec) = &mut r { vec[idx] = v; } r } - fn iter(&self) -> impl Iterator { - std::iter::empty() + fn singleton_from_row(r: &TypeRow, idx: usize, v: PartialValue) -> Self { + Self::singleton(r.len(),idx,v) + } + + fn top_from_row(r: &TypeRow) -> Self { + Self::new(r.len()) + } + + fn iter<'a>(&'a self, h: &'a impl HugrView, n: Node) -> impl Iterator + 'a { + match self { + Self::Values(v) => { + either::Either::Left(zip_eq(h.node_inputs(n), v.iter())) + } + Self::Bottom => either::Either::Right(h.node_inputs(n).map(|x| (x,PartialValue::BOTTOM_REF))) + } + } + + fn initialised(&self) -> bool { + if let Self::Values(v) = self { + v.iter().all(|x| x != &PartialValue::Top) + } else { + true + } } } @@ -201,9 +231,18 @@ fn node_in_value_row<'a>( std::iter::empty() } -fn tc(hugr: &impl HugrView, node: Node) { +fn tc(hugr: &impl HugrView, node: Node) -> Vec<(Node, OutgoingPort, PartialValue)> { assert!(OpTag::DataflowParent.is_superset(hugr.get_optype(node).tag())); let d = DescendantsGraph::<'_, Node>::try_new(hugr, node).unwrap(); + let mut cache = ValueCache::new(); + + let singleton_in_row = |n: &Node, ip: &IncomingPort, v: &PartialValue| -> ValueRow { + ValueRow::singleton_from_row(&hugr.signature(*n).unwrap().input, ip.index(), v.clone()) + }; + + let top_row = |n: &Node| -> ValueRow { + ValueRow::top_from_row(&hugr.signature(*n).unwrap().input) + }; ascent_run! { relation node(Node) = d.nodes().map(|x| (x,)).collect_vec(); @@ -213,25 +252,29 @@ fn tc(hugr: &impl HugrView, node: Node) { relation out_wire(Node, OutgoingPort); out_wire(n,p) <-- node(n), for p in d.node_outputs(*n); - lattice node_in_value_row(Node, ValueRow); - node_in_value_row(n, ValueRow::new(&hugr.signature(*n).unwrap().input)) <-- node(n); - lattice out_wire_value(Node, OutgoingPort, PartialValue); - out_wire_value(n,p, PartialValue::top_from_hugr(hugr,*n,*p)) <-- out_wire(n,p); + out_wire_value(n,p, PartialValue::Top) <-- out_wire(n,p); - node_in_value_row(n,ValueRow::singleton(&hugr.signature(*n).unwrap().input, ip.index(), v.clone())) <-- in_wire(n, ip), - if let Some((m,op)) = hugr.single_linked_output(*n, *ip), out_wire_value(m, op, ?v); + lattice node_in_value_row(Node, ValueRow); + node_in_value_row(n, top_row(n)) <-- node(n); + node_in_value_row(n, singleton_in_row(n,ip,v)) <-- in_wire(n, ip), + if let Some((m,op)) = hugr.single_linked_output(*n, *ip), + out_wire_value(m, op, v); lattice in_wire_value(Node, IncomingPort, PartialValue); - in_wire_value(n,p,v) <-- node_in_value_row(n, ?vr), for (p,v) in vr.iter(); + in_wire_value(n,p,v) <-- node_in_value_row(n, vr), for (p,v) in vr.iter(hugr,*n); relation load_constant_node(Node); load_constant_node(n) <-- node(n), if hugr.get_optype(*n).is_load_constant(); - out_wire_value(n, 0.into(), PartialValue::from_load_constant(hugr, *n)) <-- load_constant_node(n); + + out_wire_value(n, 0.into(), PartialValue::from_load_constant(&mut cache, hugr, *n)) <-- + load_constant_node(n); relation make_tuple_node(Node); make_tuple_node(n) <-- node(n), if hugr.get_optype(*n).is_make_tuple(); - out_wire_value(n,0.into(), PartialValue::tuple_from_value_row(vs)) <-- make_tuple_node(n), node_in_value_row(n, ?vs); - }; + out_wire_value(n,0.into(), PartialValue::tuple_from_value_row(vs)) <-- + make_tuple_node(n), node_in_value_row(n, vs); + + }.out_wire_value } From 39f6826d6cecc3a68afcc21b8220a00f5c9b4adc Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Fri, 7 Jun 2024 05:46:04 +0100 Subject: [PATCH 03/12] wip --- hugr-core/src/ops/constant.rs | 21 +- hugr-passes/src/const_fold2/datalog.rs | 625 ++++++++++++++---- .../src/const_fold2/datalog/context.rs | 154 +++++ 3 files changed, 658 insertions(+), 142 deletions(-) create mode 100644 hugr-passes/src/const_fold2/datalog/context.rs diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 5f059f4f0..f72e95a6d 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -98,7 +98,7 @@ impl AsRef for Const { } } -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] #[serde(tag = "v")] /// A value that can be stored as a static constant. Representing core types and /// extension types. @@ -136,6 +136,25 @@ pub enum Value { }, } +impl PartialEq for Value { + fn eq(&self, other: &Self) -> bool { + let cmp_tuple_sum = |tuple_values: &[Value], tag: usize, sum_values: &[Value], sum_type: &SumType| { + tag == 0 && sum_type.num_variants() == 1 && tuple_values == sum_values + }; + match (self, other) { + (Self::Extension { e: e1 }, Self::Extension { e: e2 }) => e1 == e2, + (Self::Function { hugr: h1 }, Self::Function { hugr: h2 }) => h1 == h2, + (Self::Tuple { vs: v1 }, Self::Tuple { vs: v2 }) => v1 == v2, + (Self::Sum { tag: t1, values: v1, sum_type: s1 }, Self::Sum { tag: t2, values: v2, sum_type: s2 }) => { + t1 == t2 && v1 == v2 && s1 == s2 + } + (Self::Tuple { vs }, Self::Sum {tag, values, sum_type}) | + (Self::Sum {tag, values, sum_type}, Self::Tuple { vs }) => cmp_tuple_sum(vs, *tag, values, sum_type), + _ => false, + } + } +} + /// An opaque newtype around a [`Box`](CustomConst). /// /// This type has special serialization behaviour in order to support diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index f11646ce5..5edc2e467 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -1,14 +1,22 @@ use std::hash::{Hash, Hasher}; use ascent::{ascent_run, Lattice}; +use either::Either; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::{OpTag, OpTrait, Value}; -use hugr_core::types::{SumType, Type, TypeRow}; +use hugr_core::types::{FunctionType, SumType, Type, TypeEnum, TypeRow}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use itertools::{zip_eq, Itertools}; use std::collections::HashMap; +use std::sync::Arc; -#[derive(PartialEq, Clone, Eq)] +mod context; + +use context::DataflowContext; + +use self::context::ValueHandle; + +#[derive(PartialEq, Clone, Eq, Debug)] struct HashableHashMap(HashMap); impl Hash for HashableHashMap { @@ -18,58 +26,156 @@ impl Hash for HashableHashMap { } } -struct ValueCache(HashMap); - -impl ValueCache { - fn new() -> Self { - Self(HashMap::new()) - } - - fn get(&mut self, node: Node, value: &Value) -> ValueHandle { - self.0.entry(node).or_insert_with(|| value.clone()); - ValueHandle(node) - } -} - -#[derive(PartialEq,Eq,Clone,Hash)] -struct ValueHandle(Node); - -impl ValueHandle { - fn new(node: Node) -> Self { - Self(node) - } -} - -#[derive(PartialEq, Clone, Eq, Hash)] +#[derive(PartialEq, Clone, Eq, Hash, Debug)] enum PartialValue { Bottom, - Value(ValueHandle), + Value(context::ValueHandle), PartialSum(HashableHashMap>), Top, } +impl From for PartialValue { + fn from(v: ValueHandle) -> Self { + match v.value() { + Value::Tuple { vs } => { + let vec = (0..vs.len()).map(|i| PartialValue::from(v.index(i)).into()).collect(); + Self::PartialSum(HashableHashMap([(0, vec)].into_iter().collect())) + } + Value::Sum { tag, values, .. } => { + let vec = (0..values.len()).map(|i| PartialValue::from(v.index(i)).into()).collect(); + Self::PartialSum(HashableHashMap([(*tag, vec)].into_iter().collect())) + } + _ => Self::Value(v) + } + } +} + impl PartialValue { const BOTTOM: Self = Self::Bottom; const BOTTOM_REF: &'static Self = &Self::BOTTOM; - fn from_load_constant(cache: &mut ValueCache, hugr: &impl HugrView, node: Node) -> Self { - let load_op = hugr.get_optype(node).as_load_constant().unwrap(); - let const_node = hugr + + fn initialised(&self) -> bool { + !self.is_top() + } + + fn is_top(&self) -> bool { self == &PartialValue::Top } + + fn from_load_constant<'a, H: HugrView>( + context: &context::DataflowContext<'a, H>, + node: Node, + ) -> Self { + let load_op = context.hugr().get_optype(node).as_load_constant().unwrap(); + let const_node = context + .hugr() .single_linked_output(node, load_op.constant_port()) .unwrap() .0; - let const_op = hugr.get_optype(const_node).as_const().unwrap(); - Self::Value(cache.get(const_node, const_op.value())) + let const_op = context.hugr().get_optype(const_node).as_const().unwrap(); + context.value_handle(const_node, const_op.value()).into() } fn tuple_from_value_row(r: &ValueRow) -> Self { - if !r.initialised() { - return Self::Top + // if !r.initialised() { + // return Self::Top; + // } + PartialValue::PartialSum(HashableHashMap([(0usize, r.0.clone())].into_iter().collect())) + } + + fn tuple_field_value(&self, idx: usize) -> Self { + match self { + Self::Top => Self::Top, + Self::PartialSum(HashableHashMap(hm)) => { + if let Ok((0, row)) = hm.iter().exactly_one() { + assert!(row.len() > idx); + row[idx].clone() + } else { + Self::Bottom + } + }, + Self::Value(v) => Self::Value(v.index(idx)), + _ => Self::Bottom } - match r { - ValueRow::Bottom => Self::Bottom, - ValueRow::Values(vs) => { - PartialValue::PartialSum(HashableHashMap([(0usize, vs.clone())].into_iter().collect())) - } + } + + fn variant_field_value(&self, variant: usize, idx: usize) -> Self { + match self { + Self::Bottom => Self::Bottom, + Self::PartialSum(HashableHashMap(hm)) => { + if let Some(row) = hm.get(&variant) { + assert!(row.len() > idx); + row[idx].clone() + } else { + // We must return top. if self were to gain this variant, we would return the element of that variant. + // We must ensure that the value return now is <= that future value + Self::Top + } + }, + Self::Value(v) if v.tag() == variant => { + Self::Value(v.index(idx)) + }, + _ => Self::Top + } + } + + pub fn try_into_value(self, typ: &Type) -> Result { + let r = match self { + Self::Value(v) => v.value().clone(), + Self::PartialSum(HashableHashMap(hm)) => { + let err = |hm| Err(Self::PartialSum(HashableHashMap(hm))); + let Ok((k,v)) = hm.iter().exactly_one() else { + return err(hm); + }; + let TypeEnum::Sum(st) = typ.as_type_enum() else { + return err(hm); + }; + let Some(r) = st.get_variant(*k) else { + return err(hm); + }; + if v.len() != r.len() { + return err(hm); + } + + let Ok(vs) = zip_eq(v.into_iter(), r.into_iter()).map(|(v,t)| v.clone().try_into_value(t)).collect::, _>>() else { + return err(hm); + }; + + Value::sum(*k, vs, st.clone()).map_err(|_| Self::PartialSum(HashableHashMap(hm)))? + }, + x => Err(x)? + + }; + assert_eq!(typ, &r.get_type()); + Ok(r) + } + + fn join_value_handle(&mut self, vh: ValueHandle) -> bool { + let mut new_self = self; + match &mut new_self { + Self::Bottom => { + false + }, + s@Self::Value(_) => { + let Self::Value(v) = *s else { unreachable!() }; + if v == &vh { + false + } else { + **s = Self::Bottom; + true + } + }, + s@Self::PartialSum(_) => { + match vh.into() { + Self::Value(_) => { + **s = Self::Bottom; + true + } + other => s.join_mut(other) + } + }, + s@Self::Top => { + **s = vh.into(); + true + } } } } @@ -93,23 +199,25 @@ impl Lattice for PartialValue { } fn join_mut(&mut self, other: Self) -> bool { - match (self, other) { + // println!("join {self:?}\n{:?}", &other); + let mut s = self; + let changed = match (&mut s, other) { (Self::Bottom, _) => false, - (s, Self::Bottom) => { - *s = Self::Bottom; + (s, other@Self::Bottom) => { + **s = other; true - } + }, (_, Self::Top) => false, - (s @ Self::Top, x) => { - *s = Self::Top; + (s@Self::Top, other) => { + **s = other; true } - (Self::Value(h1), Self::Value(h2)) if h1 == &h2 => false, - ( - Self::PartialSum(HashableHashMap(hm1)), - Self::PartialSum(HashableHashMap(hm2)) - ) => { + (Self::Value(h1), Self::Value(h2)) if h1 == &h2 || h1.value() == h2.value() => { + false + } + (s@Self::PartialSum(_), Self::PartialSum(HashableHashMap(hm2))) => { let mut changed = false; + let Self::PartialSum(HashableHashMap(hm1)) = *s else { unreachable!() }; for (k, v) in hm2 { if let Some(row) = hm1.get_mut(&k) { for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { @@ -122,11 +230,24 @@ impl Lattice for PartialValue { } changed } - (s, _) => { - *s = Self::Bottom; - true + (s@Self::Value(_), other@Self::PartialSum(_)) => { + let mut old_self = other; + std::mem::swap(*s, &mut old_self); + let Self::Value(h) = old_self else { unreachable!() }; + s.join_value_handle(h) } - } + (s@Self::PartialSum(_), Self::Value(h)) => { + s.join_value_handle(h) + } + (s,_) => { + **s = Self::Bottom; + false + } + }; + // if changed { + // println!("join new self: {:?}", s); + // } + changed } } @@ -135,58 +256,47 @@ impl Lattice for PartialValue { // } #[derive(PartialEq, Clone, Eq, Hash, PartialOrd)] -enum ValueRow { - Values(Vec), - Bottom, -} - +struct ValueRow(Vec); impl ValueRow { - fn into_partial_value(self) -> PartialValue { - todo!() - } + // fn into_partial_value(self) -> PartialValue { + // todo!() + // } fn new(len: usize) -> Self { - Self::Values(vec![PartialValue::Top; len]) + Self(vec![PartialValue::Top; len]) } fn singleton(len: usize, idx: usize, v: PartialValue) -> Self { assert!(idx < len); let mut r = Self::new(len); - if let Self::Values(vec) = &mut r { - vec[idx] = v; - } + r.0[idx] = v; r } fn singleton_from_row(r: &TypeRow, idx: usize, v: PartialValue) -> Self { - Self::singleton(r.len(),idx,v) + Self::singleton(r.len(), idx, v) } fn top_from_row(r: &TypeRow) -> Self { Self::new(r.len()) } - fn iter<'a>(&'a self, h: &'a impl HugrView, n: Node) -> impl Iterator + 'a { - match self { - Self::Values(v) => { - either::Either::Left(zip_eq(h.node_inputs(n), v.iter())) - } - Self::Bottom => either::Either::Right(h.node_inputs(n).map(|x| (x,PartialValue::BOTTOM_REF))) - } + fn iter<'b, 'a, H: HugrView>( + &'b self, + context: &'b Ctx<'a,H>, + n: Node, + ) -> impl Iterator + 'b { + zip_eq(value_inputs(context, n), self.0.iter()) } fn initialised(&self) -> bool { - if let Self::Values(v) = self { - v.iter().all(|x| x != &PartialValue::Top) - } else { - true - } + self.0.iter().all(|x| x != &PartialValue::Top) } } impl Lattice for ValueRow { - fn meet(self, other: Self) -> Self { + fn meet(self, _other: Self) -> Self { todo!() } @@ -196,85 +306,318 @@ impl Lattice for ValueRow { } fn join_mut(&mut self, other: Self) -> bool { - match (self, other) { - (Self::Bottom, _) => false, - (s, o @ Self::Bottom) => { - *s = o; - true - } - (s, Self::Values(vs2)) => { - let (b, r) = if let Self::Values(vs1) = s { - if vs1.len() == vs2.len() { - let mut changed = false; - for (v1, v2) in zip_eq(vs1.iter_mut(), vs2.into_iter()) { - changed |= v1.join_mut(v2); - } - (false, changed) - } else { - (true, true) - } - } else { - panic!("impossible") - }; - if b { - *s = Self::Bottom; - } - r - } + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.join_mut(v2); } + changed } } -fn node_in_value_row<'a>( - ins: impl Iterator, -) -> impl Iterator { - std::iter::empty() -} -fn tc(hugr: &impl HugrView, node: Node) -> Vec<(Node, OutgoingPort, PartialValue)> { - assert!(OpTag::DataflowParent.is_superset(hugr.get_optype(node).tag())); - let d = DescendantsGraph::<'_, Node>::try_new(hugr, node).unwrap(); - let mut cache = ValueCache::new(); +type ArcCtx<'a, H: HugrView> = Arc>; +type Ctx<'a, H: HugrView> = DataflowContext<'a,H>; + +fn top_row<'a, H: HugrView>(context: &Ctx<'a, H>, n: Node) -> ValueRow { + if let Some(sig) = context.hugr().signature(n) { + ValueRow::new(sig.input_count()) + } else { + ValueRow::new(0) + } +} - let singleton_in_row = |n: &Node, ip: &IncomingPort, v: &PartialValue| -> ValueRow { - ValueRow::singleton_from_row(&hugr.signature(*n).unwrap().input, ip.index(), v.clone()) +fn singleton_in_row<'a, H: HugrView>( + context: &Ctx<'a, H>, + n: &Node, + ip: &IncomingPort, + v: &PartialValue, +) -> ValueRow { + let Some(sig) = context.hugr().signature(*n) else { + panic!("dougrulz"); }; + if sig.input_count() <= ip.index() { + panic!("bad port index: {} >= {}: {}", ip.index(), sig.input_count(), context.hugr().get_optype(*n).description()); + } + ValueRow::singleton_from_row( + &context.hugr().signature(*n).unwrap().input, + ip.index(), + v.clone(), + ) +} + +#[derive(Debug,Clone,Copy,PartialEq,Eq,Hash)] +enum IO { Input, Output } + +fn value_inputs<'a,H: HugrView>(context: &Ctx<'a, H>, n: Node) -> impl Iterator + 'a { + context.hugr().in_value_types(n).map(|x| x.0) +} + +fn value_outputs<'a,H: HugrView>(context: &Ctx<'a, H>, n: Node) -> impl Iterator + 'a { + context.hugr().out_value_types(n).map(|x| x.0) +} - let top_row = |n: &Node| -> ValueRow { - ValueRow::top_from_row(&hugr.signature(*n).unwrap().input) +fn tail_loop_worker<'b, 'a,H: HugrView>(context: &Ctx<'a, H>, n: Node, output_p: IncomingPort, control_variant: usize, v: &'b PartialValue) -> impl Iterator + 'b { + let tail_loop_op = context.get_optype(n).as_tail_loop().unwrap(); + let num_variant_vals = if control_variant == 0 { + tail_loop_op.just_inputs.len() + } else { + tail_loop_op.just_outputs.len() }; - ascent_run! { - relation node(Node) = d.nodes().map(|x| (x,)).collect_vec(); + if output_p.index() == 0 { + Either::Left((0..num_variant_vals).map(move |i| (i.into(), v.variant_field_value(control_variant, i)))) + } else { + Either::Right(std::iter::once(((num_variant_vals + output_p.index()).into(), v.clone()))) + } +} + +ascent::ascent! { + struct Dataflow<'a, H: HugrView>; + relation context(ArcCtx<'a, H>); + relation node(ArcCtx<'a, H>, Node); + relation in_wire(ArcCtx<'a,H>, Node, IncomingPort); + relation out_wire(ArcCtx<'a,H>, Node, OutgoingPort); + lattice out_wire_value(ArcCtx<'a,H>, Node, OutgoingPort, PartialValue); + lattice node_in_value_row(ArcCtx<'a,H>, Node, ValueRow); + lattice in_wire_value(ArcCtx<'a,H>, Node, IncomingPort, PartialValue); + + node(c, n) <-- context(c), for n in c.nodes(); + + in_wire(c, n,p) <-- node(c, n), for p in value_inputs(c, *n); + + out_wire(c, n,p) <-- node(c, n), for p in value_outputs(c, *n); + + // All out wire values are initialised to Top. If any value is Top after + // running we can infer that execution never reaches that value. + out_wire_value(c, n,p, PartialValue::Top) <-- out_wire(c, n,p); + + in_wire_value(c, n, ip, v) <-- in_wire(c, n, ip), + if let Some((m,op)) = c.single_linked_output(*n, *ip), + out_wire_value(c, m, op, v); + + + node_in_value_row(c, n, top_row(c, *n)) <-- node(c, n); + node_in_value_row(c, n, singleton_in_row(c, n, p, v)) <-- in_wire_value(c, n, p, v); + + // LoadConstant + relation load_constant_node(ArcCtx<'a, H>, Node); + load_constant_node(c, n) <-- node(c, n), if c.get_optype(*n).is_load_constant(); + + out_wire_value(c, n, 0.into(), PartialValue::from_load_constant(c, *n)) <-- + load_constant_node(c, n); + + // MakeTuple + relation make_tuple_node(ArcCtx<'a,H>, Node); + make_tuple_node(c, n) <-- node(c, n), if c.get_optype(*n).is_make_tuple(); + + out_wire_value(c, n, 0.into(), PartialValue::tuple_from_value_row(vs)) <-- + make_tuple_node(c, n), node_in_value_row(c, n, vs); + + // UnpackTuple + relation unpack_tuple_node(ArcCtx<'a, H>, Node); + unpack_tuple_node(c,n) <-- node(c, n), if c.get_optype(*n).is_unpack_tuple(); + + out_wire_value(c, n, p, v.tuple_field_value(p.index())) <-- unpack_tuple_node(c, n), in_wire_value(c, n, IncomingPort::from(0), v), out_wire(c, n, p); + + // DFG + relation dfg_node(ArcCtx<'a, H>, Node); + dfg_node(c,n) <-- node(c, n), if c.get_optype(*n).is_dfg(); + relation dfg_io_node(ArcCtx<'a, H>, Node, Node, IO); + dfg_io_node(c,dfg,n,io) <-- dfg_node(c,dfg), + if let Some([i,o]) = c.get_io(*dfg), + for (n, io) in [(i, IO::Input), (o, IO::Output)]; + + out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- + dfg_io_node(c,dfg,i, IO::Input), in_wire_value(c, dfg, p, v); + out_wire_value(c, dfg, OutgoingPort::from(p.index()), v) <-- + dfg_io_node(c,dfg,o, IO::Output), in_wire_value(c, o, p, v); - relation in_wire(Node, IncomingPort); - in_wire(n,p) <-- node(n), for p in d.node_inputs(*n); - relation out_wire(Node, OutgoingPort); - out_wire(n,p) <-- node(n), for p in d.node_outputs(*n); + // TailLoop + relation tail_loop_node(ArcCtx<'a, H>, Node); + tail_loop_node(c,n) <-- node(c, n), if c.get_optype(*n).is_tail_loop(); + relation tail_loop_io_node(ArcCtx<'a, H>, Node, Node, IO); + tail_loop_io_node(c,tl,n, io) <-- tail_loop_node(c,tl), + if let Some([i,o]) = c.get_io(*tl), + for (n,io) in [(i,IO::Input), (o, IO::Output)]; - lattice out_wire_value(Node, OutgoingPort, PartialValue); - out_wire_value(n,p, PartialValue::Top) <-- out_wire(n,p); + // inputs of tail loop propagate to Input node of child region + out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- + tail_loop_io_node(c,tl,i, IO::Input), in_wire_value(c, tl, p, v); + // Output node of child region propagate to Input node of child region + out_wire_value(c, i, input_p, v) <-- + tail_loop_io_node(c,tl,i, IO::Input), + tail_loop_io_node(c,tl,o, IO::Output), + in_wire_value(c, o, output_p, output_v), + for (input_p, v) in tail_loop_worker(c, *tl, *output_p, 0, output_v); + // Output node of child region propagate to outputs of tail loop + out_wire_value(c, tl, p, v) <-- + tail_loop_io_node(c,tl,o, IO::Output), + in_wire_value(c, o, output_p, output_v), + for (p, v) in tail_loop_worker(c, *tl, *output_p, 1, output_v); - lattice node_in_value_row(Node, ValueRow); - node_in_value_row(n, top_row(n)) <-- node(n); - node_in_value_row(n, singleton_in_row(n,ip,v)) <-- in_wire(n, ip), - if let Some((m,op)) = hugr.single_linked_output(*n, *ip), - out_wire_value(m, op, v); - lattice in_wire_value(Node, IncomingPort, PartialValue); - in_wire_value(n,p,v) <-- node_in_value_row(n, vr), for (p,v) in vr.iter(hugr,*n); +} + +impl<'a, H: HugrView> Dataflow<'a, H> { + pub fn new() -> Self { + Self::default() + } + + pub fn run_hugr(&mut self, hugr: &'a H) -> ArcCtx<'a,H> { + let context = context::DataflowContext::new(hugr); + self.context.push((context.clone(),)); + self.run(); + context + } + + pub fn read_out_wire_partial_value(&self, context: &Ctx<'a,H>, w: Wire) -> Option { + self.out_wire_value.iter().find_map(|(c,n,p,v)| (c.as_ref() == context && &w.node() == n && &w.source() == p).then_some(v.clone())) + } + + pub fn read_out_wire_value(&self, context: &Ctx<'a,H>, w: Wire) -> Option { + // dbg!(&w); + let pv = self.read_out_wire_partial_value(context, w)?; + // dbg!(&pv); + let (_, typ) = context.hugr().out_value_types(w.node()).find(|(p,_)| *p == w.source()).unwrap(); + pv.try_into_value(&typ).ok() + } +} + +#[cfg(test)] +mod test { + use hugr_core::{builder::{DFGBuilder, Dataflow, HugrBuilder, SubContainer}, extension::{prelude::BOOL_T, EMPTY_REG}, ops::{UnpackTuple, Value}, type_row, types::{FunctionType, SumType}}; + + use crate::const_fold2::datalog::PartialValue; + + + #[test] + fn test_make_tuple() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let v1 = builder.add_load_value(Value::false_val()); + let v2 = builder.add_load_value(Value::true_val()); + let v3 = builder.make_tuple([v1,v2]).unwrap(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = super::Dataflow::new(); + let c = machine.run_hugr(&hugr); + + let x = machine.read_out_wire_value(&c, v3).unwrap(); + assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); + } + + #[test] + fn test_unpack_tuple() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let v1 = builder.add_load_value(Value::false_val()); + let v2 = builder.add_load_value(Value::true_val()); + let v3 = builder.make_tuple([v1,v2]).unwrap(); + let [o1,o2] = builder.add_dataflow_op(UnpackTuple::new(type_row![BOOL_T,BOOL_T]), [v3]).unwrap().outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = super::Dataflow::new(); + let c = machine.run_hugr(&hugr); + + let o1_r = machine.read_out_wire_value(&c, o1).unwrap(); + assert_eq!(o1_r, Value::false_val()); + let o2_r = machine.read_out_wire_value(&c, o2).unwrap(); + assert_eq!(o2_r, Value::true_val()); + } - relation load_constant_node(Node); - load_constant_node(n) <-- node(n), if hugr.get_optype(*n).is_load_constant(); + #[test] + fn test_unpack_const() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let v1 = builder.add_load_value(Value::tuple([Value::true_val()])); + let [o] = builder.add_dataflow_op(UnpackTuple::new(type_row![BOOL_T]), [v1]).unwrap().outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - out_wire_value(n, 0.into(), PartialValue::from_load_constant(&mut cache, hugr, *n)) <-- - load_constant_node(n); + let mut machine = super::Dataflow::new(); + let c = machine.run_hugr(&hugr); - relation make_tuple_node(Node); - make_tuple_node(n) <-- node(n), if hugr.get_optype(*n).is_make_tuple(); + let o_r = machine.read_out_wire_value(&c, o).unwrap(); + assert_eq!(o_r, Value::true_val()); + } - out_wire_value(n,0.into(), PartialValue::tuple_from_value_row(vs)) <-- - make_tuple_node(n), node_in_value_row(n, vs); + #[test] + fn test_tail_loop_never_iterates() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let r_v = Value::unit_sum(3,6).unwrap(); + let r_w = builder.add_load_value(Value::sum(1, [r_v.clone()], SumType::new([type_row![], r_v.get_type().into()])).unwrap()); + let tlb = builder.tail_loop_builder([], [], vec![r_v.get_type()].into()).unwrap(); + let [tl_o] = tlb.finish_with_outputs(r_w, []).unwrap().outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = super::Dataflow::new(); + let c = machine.run_hugr(&hugr); + // dbg!(&machine.tail_loop_io_node); + // dbg!(&machine.out_wire_value); + + let o_r = machine.read_out_wire_value(&c, tl_o).unwrap(); + assert_eq!(o_r, r_v); + } - }.out_wire_value + #[test] + fn test_tail_loop_always_iterates() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let r_w = builder.add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); + let tlb = builder.tail_loop_builder([], [], vec![BOOL_T].into()).unwrap(); + let [tl_o] = tlb.finish_with_outputs(r_w, []).unwrap().outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = super::Dataflow::new(); + let c = machine.run_hugr(&hugr); + // dbg!(&machine.tail_loop_io_node); + // dbg!(&machine.out_wire_value); + + let o_r = machine.read_out_wire_partial_value(&c, tl_o).unwrap(); + assert_eq!(o_r, PartialValue::Top); + } } + +// fn tc(hugr: &impl HugrView, node: Node) -> Vec<(Node, OutgoingPort, PartialValue)> { +// assert!(OpTag::DataflowParent.is_superset(hugr.get_optype(node).tag())); +// let d = DescendantsGraph::<'_, Node>::try_new(hugr, node).unwrap(); +// let mut cache = ValueCache::new(); + +// let singleton_in_row = |n: &Node, ip: &IncomingPort, v: &PartialValue| -> ValueRow { +// ValueRow::singleton_from_row(&hugr.signature(*n).unwrap().input, ip.index(), v.clone()) +// }; + +// let top_row = |n: &Node| -> ValueRow { +// ValueRow::top_from_row(&hugr.signature(*n).unwrap().input) +// }; +// // ascent! { +// // relation node(Node) = d.nodes().map(|x| (x,)).collect_vec(); + +// // relation in_wire(Node, IncomingPort); +// // in_wire(n,p) <-- node(n), for p in d.node_inputs(*n); + +// // relation out_wire(Node, OutgoingPort); +// // out_wire(n,p) <-- node(n), for p in d.node_outputs(*n); + +// // lattice out_wire_value(Node, OutgoingPort, PartialValue); +// // out_wire_value(n,p, PartialValue::Top) <-- out_wire(n,p); + +// // lattice node_in_value_row(Node, ValueRow); +// // node_in_value_row(n, top_row(n)) <-- node(n); +// // node_in_value_row(n, singleton_in_row(n,ip,v)) <-- in_wire(n, ip), +// // if let Some((m,op)) = hugr.single_linked_output(*n, *ip), +// // out_wire_value(m, op, v); + +// // lattice in_wire_value(Node, IncomingPort, PartialValue); +// // in_wire_value(n,p,v) <-- node_in_value_row(n, vr), for (p,v) in vr.iter(hugr,*n); + +// // relation load_constant_node(Node); +// // load_constant_node(n) <-- node(n), if hugr.get_optype(*n).is_load_constant(); + +// // out_wire_value(n, 0.into(), PartialValue::from_load_constant(&mut cache, hugr, *n)) <-- +// // load_constant_node(n); + +// // relation make_tuple_node(Node); +// // make_tuple_node(n) <-- node(n), if hugr.get_optype(*n).is_make_tuple(); + +// // out_wire_value(n,0.into(), PartialValue::tuple_from_value_row(vs)) <-- +// // make_tuple_node(n), node_in_value_row(n, vs); + +// // }.out_wire_value +// } diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs new file mode 100644 index 000000000..d633d1e25 --- /dev/null +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -0,0 +1,154 @@ +use std::cell::RefCell; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::ops::Deref; +use std::sync::atomic::AtomicUsize; +use std::sync::Arc; + +use ascent::Lattice; + +use either::Either; +use hugr_core::ops::{OpTag, OpTrait, Value}; +use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; + +#[derive(Clone, Debug)] +pub struct ValueHandle(Vec, Node, Arc); + +impl ValueHandle { + pub fn value(&self) -> &Value { + self.2.as_ref() + } + + pub fn tag(&self) -> usize { + match self.value() { + Value::Sum { tag, .. } => *tag, + Value::Tuple { .. } => 0, + _ => panic!("ValueHandle::tag called on non-Sum, non-Tuple value"), + } + } + + pub fn index(self: &ValueHandle, i: usize) -> ValueHandle { + let vs = match self.value() { + Value::Sum { values, .. } => values, + Value::Tuple { vs, .. } => vs, + _ => panic!("ValueHandle::index called on non-Sum, non-Tuple value"), + }; + assert!(i < vs.len()); + let v = vs[i].clone().into(); + let mut is = self.0.clone(); + is.push(i); + Self(is, self.1, v) + } +} + +impl PartialEq for ValueHandle { + fn eq(&self, other: &Self) -> bool { + (&self.0, self.1) == (&other.0, other.1) + } +} + +impl Eq for ValueHandle {} + +impl Hash for ValueHandle { + fn hash(&self, state: &mut I) { + self.0.hash(state); + self.1.hash(state); + } +} + +impl Deref for ValueHandle { + type Target = Value; + + fn deref(&self) -> &Self::Target { + self.value() + } +} + +#[derive(Clone)] +pub struct ValueCache(HashMap>); + +impl ValueCache { + fn new() -> Self { + Self(HashMap::new()) + } + + fn get(&mut self, node: Node, value: &Value) -> ValueHandle { + let v = self.0.entry(node).or_insert_with(|| value.clone().into()).clone(); + ValueHandle(vec![], node, v) + } +} + + +static mut CONTEXT_ID: AtomicUsize = AtomicUsize::new(0); + +fn next_context_id() -> usize { + unsafe { CONTEXT_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst) } +} + +pub struct DataflowContext<'a, H> { + id: usize, + hugr: &'a H, + cache: RefCell, +} + +impl<'a, H> DataflowContext<'a, H> { + pub fn new(hugr: &'a H) -> Arc { + Arc::new(Self { + id: next_context_id(), + hugr, + cache: ValueCache::new().into(), + }) + } + + pub fn value_handle(&self, node: Node, value: &Value) -> ValueHandle { + self.cache.borrow_mut().get(node, value) + } + + pub fn hugr(&self) -> &'a H { + self.hugr + } + + pub fn id(&self) -> usize { + self.id + } +} + +impl std::fmt::Debug for DataflowContext<'_, H> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DataflowContext({})", self.id) + } +} + +impl Hash for DataflowContext<'_, H> { + fn hash(&self, state: &mut I) { + self.id.hash(state); + } +} + +impl PartialEq for DataflowContext<'_, H> { + fn eq(&self, other: &usize) -> bool { + &self.id == other + } +} + +impl PartialEq for DataflowContext<'_, H> { + fn eq(&self, other: &Self) -> bool { + self == &other.id + } +} + +impl Eq for DataflowContext<'_, H> {} + +impl PartialOrd for DataflowContext<'_, H> { + fn partial_cmp(&self, _other: &Self) -> Option { + self.id.partial_cmp(&_other.id) + } +} + +impl<'a,H> Deref for DataflowContext<'a,H> { + type Target = H; + + fn deref(&self) -> &Self::Target { + self.hugr + } +} From 866bc5d38da1ac1f37f14e500c38cc1d32859806 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Fri, 7 Jun 2024 06:28:04 +0100 Subject: [PATCH 04/12] move PartialValue into hugr-core --- hugr-core/src/extension.rs | 2 +- hugr-core/src/extension/const_fold.rs | 2 + .../src/extension/const_fold/partial_value.rs | 292 ++++++++++++ hugr-passes/Cargo.toml | 1 + hugr-passes/src/const_fold2/datalog.rs | 434 +++++++----------- .../src/const_fold2/datalog/context.rs | 54 +-- 6 files changed, 465 insertions(+), 320 deletions(-) create mode 100644 hugr-core/src/extension/const_fold/partial_value.rs diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 1ef9c50cd..90ffba912 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -36,7 +36,7 @@ mod const_fold; pub mod prelude; pub mod simple_op; pub mod validate; -pub use const_fold::{ConstFold, ConstFoldResult, Folder}; +pub use const_fold::{ConstFold, ConstFoldResult, Folder, PartialValue, ValueHandle, HashableHashMap}; pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; pub mod declarative; diff --git a/hugr-core/src/extension/const_fold.rs b/hugr-core/src/extension/const_fold.rs index a3aae93eb..4485a05b0 100644 --- a/hugr-core/src/extension/const_fold.rs +++ b/hugr-core/src/extension/const_fold.rs @@ -10,6 +10,8 @@ use crate::OutgoingPort; use crate::ops; +mod partial_value; +pub use partial_value::{PartialValue, ValueHandle, HashableHashMap}; /// Output of constant folding an operation, None indicates folding was either /// not possible or unsuccessful. An empty vector indicates folding was /// successful and no values are output. diff --git a/hugr-core/src/extension/const_fold/partial_value.rs b/hugr-core/src/extension/const_fold/partial_value.rs new file mode 100644 index 000000000..f7de384a4 --- /dev/null +++ b/hugr-core/src/extension/const_fold/partial_value.rs @@ -0,0 +1,292 @@ +use std::sync::Arc; +use std::hash::{Hash, Hasher}; +use std::ops::Deref; +use std::collections::HashMap; + +use itertools::{zip_eq, Itertools as _}; + +use crate::ops::{OpTag, OpTrait, Value}; +use crate::types::{Type, TypeEnum}; +use crate::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; + +#[derive(Clone, Debug)] +pub struct ValueHandle(Vec, Node, Arc); + +impl ValueHandle { + pub fn new(node: Node, value: Arc) -> Self { + Self(vec![], node, value) + } + + pub fn value(&self) -> &Value { + self.2.as_ref() + } + + pub fn tag(&self) -> usize { + match self.value() { + Value::Sum { tag, .. } => *tag, + Value::Tuple { .. } => 0, + _ => panic!("ValueHandle::tag called on non-Sum, non-Tuple value"), + } + } + + pub fn index(self: &ValueHandle, i: usize) -> ValueHandle { + let vs = match self.value() { + Value::Sum { values, .. } => values, + Value::Tuple { vs, .. } => vs, + _ => panic!("ValueHandle::index called on non-Sum, non-Tuple value"), + }; + assert!(i < vs.len()); + let v = vs[i].clone().into(); + let mut is = self.0.clone(); + is.push(i); + Self(is, self.1, v) + } +} + +impl PartialEq for ValueHandle { + fn eq(&self, other: &Self) -> bool { + (&self.0, self.1) == (&other.0, other.1) + } +} + +impl Eq for ValueHandle {} + +impl Hash for ValueHandle { + fn hash(&self, state: &mut I) { + self.0.hash(state); + self.1.hash(state); + } +} + +/// TODO this is dodgy +/// we do not hash or compare the value, just the key +/// this means two handles with different keys, but with the same value, will +/// not compare equal. +impl Deref for ValueHandle { + type Target = Value; + + fn deref(&self) -> &Self::Target { + self.value() + } +} + +/// TODO shouldn't be pub +#[derive(PartialEq, Clone, Eq, Debug)] +pub struct HashableHashMap(HashMap); + +impl Hash for HashableHashMap { + fn hash(&self, state: &mut H) { + self.0.keys().for_each(|k| k.hash(state)); + self.0.values().for_each(|v| v.hash(state)); + } +} + +#[derive(PartialEq, Clone, Eq, Hash, Debug)] +pub enum PartialValue { + Bottom, + Value(ValueHandle), + PartialSum(HashableHashMap>), + Top, +} + +impl From for PartialValue { + fn from(v: ValueHandle) -> Self { + match v.value() { + Value::Tuple { vs } => { + let vec = (0..vs.len()).map(|i| PartialValue::from(v.index(i)).into()).collect(); + Self::PartialSum(HashableHashMap([(0, vec)].into_iter().collect())) + } + Value::Sum { tag, values, .. } => { + let vec = (0..values.len()).map(|i| PartialValue::from(v.index(i)).into()).collect(); + Self::PartialSum(HashableHashMap([(*tag, vec)].into_iter().collect())) + } + _ => Self::Value(v) + } + } +} + +impl PartialValue { + const BOTTOM: Self = Self::Bottom; + const BOTTOM_REF: &'static Self = &Self::BOTTOM; + + fn initialised(&self) -> bool { + !self.is_top() + } + + fn is_top(&self) -> bool { self == &PartialValue::Top } + + + /// TODO docs + /// just delegate to variant_field_value + pub fn tuple_field_value(&self, idx: usize) -> Self { + self.variant_field_value(0,idx) + } + + /// TODO docs + pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { + match self { + Self::Bottom => Self::Bottom, + Self::PartialSum(HashableHashMap(hm)) => { + if let Some(row) = hm.get(&variant) { + assert!(row.len() > idx); + row[idx].clone() + } else { + // We must return top. if self were to gain this variant, we would return the element of that variant. + // We must ensure that the value return now is <= that future value + Self::Top + } + }, + Self::Value(v) if v.tag() == variant => { + Self::Value(v.index(idx)) + }, + _ => Self::Top + } + } + + pub fn try_into_value(self, typ: &Type) -> Result { + let r = match self { + Self::Value(v) => v.value().clone(), + Self::PartialSum(HashableHashMap(hm)) => { + let err = |hm| Err(Self::PartialSum(HashableHashMap(hm))); + let Ok((k,v)) = hm.iter().exactly_one() else { + return err(hm); + }; + let TypeEnum::Sum(st) = typ.as_type_enum() else { + return err(hm); + }; + let Some(r) = st.get_variant(*k) else { + return err(hm); + }; + if v.len() != r.len() { + return err(hm); + } + + let Ok(vs) = zip_eq(v.into_iter(), r.into_iter()).map(|(v,t)| v.clone().try_into_value(t)).collect::, _>>() else { + return err(hm); + }; + + Value::sum(*k, vs, st.clone()).map_err(|_| Self::PartialSum(HashableHashMap(hm)))? + }, + x => Err(x)? + + }; + assert_eq!(typ, &r.get_type()); + Ok(r) + } + + fn join_value_handle(&mut self, vh: ValueHandle) -> bool { + let mut new_self = self; + match &mut new_self { + Self::Bottom => { + false + }, + s@Self::Value(_) => { + let Self::Value(v) = *s else { unreachable!() }; + if v == &vh { + false + } else { + **s = Self::Bottom; + true + } + }, + s@Self::PartialSum(_) => { + match vh.into() { + Self::Value(_) => { + **s = Self::Bottom; + true + } + other => s.join_mut(other) + } + }, + s@Self::Top => { + **s = vh.into(); + true + } + } + } + + pub fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + pub fn join_mut(&mut self, other: Self) -> bool { + // println!("join {self:?}\n{:?}", &other); + let mut s = self; + let changed = match (&mut s, other) { + (Self::Bottom, _) => false, + (s, other@Self::Bottom) => { + **s = other; + true + }, + (_, Self::Top) => false, + (s@Self::Top, other) => { + **s = other; + true + } + (Self::Value(h1), Self::Value(h2)) if h1 == &h2 || h1.value() == h2.value() => { + false + } + (s@Self::PartialSum(_), Self::PartialSum(HashableHashMap(hm2))) => { + let mut changed = false; + let Self::PartialSum(HashableHashMap(hm1)) = *s else { unreachable!() }; + for (k, v) in hm2 { + if let Some(row) = hm1.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.join_mut(rhs); + } + } else { + hm1.insert(k, v); + changed = true; + } + } + changed + } + (s@Self::Value(_), other@Self::PartialSum(_)) => { + let mut old_self = other; + std::mem::swap(*s, &mut old_self); + let Self::Value(h) = old_self else { unreachable!() }; + s.join_value_handle(h) + } + (s@Self::PartialSum(_), Self::Value(h)) => { + s.join_value_handle(h) + } + (s,_) => { + **s = Self::Bottom; + false + } + }; + // if changed { + // println!("join new self: {:?}", s); + // } + changed + } + + pub fn meet(mut self, other: Self) -> Self { + self.meet_mut(other) ; + self + } + + pub fn meet_mut(&mut self, _other: Self) -> bool { + todo!() + } + + pub fn top() -> Self { + Self::Top + } + + pub fn bottom() -> Self { + Self::Bottom + } + + pub fn variant(tag: usize, values: impl IntoIterator) -> Self { + Self::PartialSum(HashableHashMap([(tag, values.into_iter().collect())].into_iter().collect())) + } +} + +impl PartialOrd for PartialValue { + fn partial_cmp(&self, other: &Self) -> Option { + // TODO we can do better + (self == other).then_some(std::cmp::Ordering::Equal) + } +} diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 4ad98e7ab..06d9975ea 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -20,6 +20,7 @@ paste = { workspace = true } thiserror = { workspace = true } ascent = "0.6.0" either = "*" +delegate = "*" [features] extension_inference = ["hugr-core/extension_inference"] diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 5edc2e467..b497d91ee 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -1,262 +1,79 @@ +use ascent::lattice::BoundedLattice; +use delegate::delegate; +use itertools::{zip_eq, Itertools}; +use std::collections::HashMap; use std::hash::{Hash, Hasher}; +use std::sync::Arc; use ascent::{ascent_run, Lattice}; use either::Either; +use hugr_core::extension::{HashableHashMap, PartialValue, ValueHandle}; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::{OpTag, OpTrait, Value}; use hugr_core::types::{FunctionType, SumType, Type, TypeEnum, TypeRow}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; -use itertools::{zip_eq, Itertools}; -use std::collections::HashMap; -use std::sync::Arc; mod context; use context::DataflowContext; -use self::context::ValueHandle; +#[derive(PartialEq, Eq, Hash, PartialOrd, Clone, Debug)] +struct PV(PartialValue); -#[derive(PartialEq, Clone, Eq, Debug)] -struct HashableHashMap(HashMap); - -impl Hash for HashableHashMap { - fn hash(&self, state: &mut H) { - self.0.keys().for_each(|k| k.hash(state)); - self.0.values().for_each(|v| v.hash(state)); +impl From for PV { + fn from(inner: PartialValue) -> Self { + Self(inner) } } -#[derive(PartialEq, Clone, Eq, Hash, Debug)] -enum PartialValue { - Bottom, - Value(context::ValueHandle), - PartialSum(HashableHashMap>), - Top, -} - -impl From for PartialValue { - fn from(v: ValueHandle) -> Self { - match v.value() { - Value::Tuple { vs } => { - let vec = (0..vs.len()).map(|i| PartialValue::from(v.index(i)).into()).collect(); - Self::PartialSum(HashableHashMap([(0, vec)].into_iter().collect())) - } - Value::Sum { tag, values, .. } => { - let vec = (0..values.len()).map(|i| PartialValue::from(v.index(i)).into()).collect(); - Self::PartialSum(HashableHashMap([(*tag, vec)].into_iter().collect())) - } - _ => Self::Value(v) - } - } -} - -impl PartialValue { - const BOTTOM: Self = Self::Bottom; - const BOTTOM_REF: &'static Self = &Self::BOTTOM; - - fn initialised(&self) -> bool { - !self.is_top() - } - - fn is_top(&self) -> bool { self == &PartialValue::Top } - - fn from_load_constant<'a, H: HugrView>( - context: &context::DataflowContext<'a, H>, - node: Node, - ) -> Self { - let load_op = context.hugr().get_optype(node).as_load_constant().unwrap(); - let const_node = context - .hugr() - .single_linked_output(node, load_op.constant_port()) - .unwrap() - .0; - let const_op = context.hugr().get_optype(const_node).as_const().unwrap(); - context.value_handle(const_node, const_op.value()).into() - } - - fn tuple_from_value_row(r: &ValueRow) -> Self { - // if !r.initialised() { - // return Self::Top; - // } - PartialValue::PartialSum(HashableHashMap([(0usize, r.0.clone())].into_iter().collect())) - } - +impl PV { fn tuple_field_value(&self, idx: usize) -> Self { - match self { - Self::Top => Self::Top, - Self::PartialSum(HashableHashMap(hm)) => { - if let Ok((0, row)) = hm.iter().exactly_one() { - assert!(row.len() > idx); - row[idx].clone() - } else { - Self::Bottom - } - }, - Self::Value(v) => Self::Value(v.index(idx)), - _ => Self::Bottom - } + self.0.tuple_field_value(idx).into() } + /// TODO the arguments here are not pretty, two usizes, better not mix them + /// up!!! fn variant_field_value(&self, variant: usize, idx: usize) -> Self { - match self { - Self::Bottom => Self::Bottom, - Self::PartialSum(HashableHashMap(hm)) => { - if let Some(row) = hm.get(&variant) { - assert!(row.len() > idx); - row[idx].clone() - } else { - // We must return top. if self were to gain this variant, we would return the element of that variant. - // We must ensure that the value return now is <= that future value - Self::Top - } - }, - Self::Value(v) if v.tag() == variant => { - Self::Value(v.index(idx)) - }, - _ => Self::Top - } + self.0.variant_field_value(variant, idx).into() } +} - pub fn try_into_value(self, typ: &Type) -> Result { - let r = match self { - Self::Value(v) => v.value().clone(), - Self::PartialSum(HashableHashMap(hm)) => { - let err = |hm| Err(Self::PartialSum(HashableHashMap(hm))); - let Ok((k,v)) = hm.iter().exactly_one() else { - return err(hm); - }; - let TypeEnum::Sum(st) = typ.as_type_enum() else { - return err(hm); - }; - let Some(r) = st.get_variant(*k) else { - return err(hm); - }; - if v.len() != r.len() { - return err(hm); - } - - let Ok(vs) = zip_eq(v.into_iter(), r.into_iter()).map(|(v,t)| v.clone().try_into_value(t)).collect::, _>>() else { - return err(hm); - }; - - Value::sum(*k, vs, st.clone()).map_err(|_| Self::PartialSum(HashableHashMap(hm)))? - }, - x => Err(x)? - - }; - assert_eq!(typ, &r.get_type()); - Ok(r) - } - - fn join_value_handle(&mut self, vh: ValueHandle) -> bool { - let mut new_self = self; - match &mut new_self { - Self::Bottom => { - false - }, - s@Self::Value(_) => { - let Self::Value(v) = *s else { unreachable!() }; - if v == &vh { - false - } else { - **s = Self::Bottom; - true - } - }, - s@Self::PartialSum(_) => { - match vh.into() { - Self::Value(_) => { - **s = Self::Bottom; - true - } - other => s.join_mut(other) - } - }, - s@Self::Top => { - **s = vh.into(); - true - } - } +impl From for PV { + fn from(inner: ValueHandle) -> Self { + Self(inner.into()) } } -impl PartialOrd for PartialValue { - fn partial_cmp(&self, other: &Self) -> Option { - // TODO we can do better - (self == other).then_some(std::cmp::Ordering::Equal) +impl Lattice for PV { + fn meet(self, other: Self) -> Self { + self.0.meet(other.0).into() } -} -impl Lattice for PartialValue { - fn meet(self, _other: Self) -> Self { - // should not be required - todo!() + fn meet_mut(&mut self, other: Self) -> bool { + self.0.meet_mut(other.0) } - fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self + fn join(self, other: Self) -> Self { + self.0.join(other.0).into() } fn join_mut(&mut self, other: Self) -> bool { - // println!("join {self:?}\n{:?}", &other); - let mut s = self; - let changed = match (&mut s, other) { - (Self::Bottom, _) => false, - (s, other@Self::Bottom) => { - **s = other; - true - }, - (_, Self::Top) => false, - (s@Self::Top, other) => { - **s = other; - true - } - (Self::Value(h1), Self::Value(h2)) if h1 == &h2 || h1.value() == h2.value() => { - false - } - (s@Self::PartialSum(_), Self::PartialSum(HashableHashMap(hm2))) => { - let mut changed = false; - let Self::PartialSum(HashableHashMap(hm1)) = *s else { unreachable!() }; - for (k, v) in hm2 { - if let Some(row) = hm1.get_mut(&k) { - for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { - changed |= lhs.join_mut(rhs); - } - } else { - hm1.insert(k, v); - changed = true; - } - } - changed - } - (s@Self::Value(_), other@Self::PartialSum(_)) => { - let mut old_self = other; - std::mem::swap(*s, &mut old_self); - let Self::Value(h) = old_self else { unreachable!() }; - s.join_value_handle(h) - } - (s@Self::PartialSum(_), Self::Value(h)) => { - s.join_value_handle(h) - } - (s,_) => { - **s = Self::Bottom; - false - } - }; - // if changed { - // println!("join new self: {:?}", s); - // } - changed + self.0.join_mut(other.0) } } -// fn input_row<'a>(inp: impl Iterator) -> impl Iterator { -// todo!() -// } +impl BoundedLattice for PV { + fn bottom() -> Self { + PartialValue::bottom().into() + } + + fn top() -> Self { + PartialValue::top().into() + } +} #[derive(PartialEq, Clone, Eq, Hash, PartialOrd)] -struct ValueRow(Vec); +struct ValueRow(Vec); impl ValueRow { // fn into_partial_value(self) -> PartialValue { @@ -264,17 +81,17 @@ impl ValueRow { // } fn new(len: usize) -> Self { - Self(vec![PartialValue::Top; len]) + Self(vec![PV::top(); len]) } - fn singleton(len: usize, idx: usize, v: PartialValue) -> Self { + fn singleton(len: usize, idx: usize, v: PV) -> Self { assert!(idx < len); let mut r = Self::new(len); r.0[idx] = v; r } - fn singleton_from_row(r: &TypeRow, idx: usize, v: PartialValue) -> Self { + fn singleton_from_row(r: &TypeRow, idx: usize, v: PV) -> Self { Self::singleton(r.len(), idx, v) } @@ -284,14 +101,14 @@ impl ValueRow { fn iter<'b, 'a, H: HugrView>( &'b self, - context: &'b Ctx<'a,H>, + context: &'b Ctx<'a, H>, n: Node, - ) -> impl Iterator + 'b { + ) -> impl Iterator + 'b { zip_eq(value_inputs(context, n), self.0.iter()) } fn initialised(&self) -> bool { - self.0.iter().all(|x| x != &PartialValue::Top) + self.0.iter().all(|x| x != &PV::top()) } } @@ -315,9 +132,18 @@ impl Lattice for ValueRow { } } +impl IntoIterator for ValueRow { + type Item = PV; + + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} -type ArcCtx<'a, H: HugrView> = Arc>; -type Ctx<'a, H: HugrView> = DataflowContext<'a,H>; +type ArcCtx<'a, H: HugrView> = Arc>; +type Ctx<'a, H: HugrView> = DataflowContext<'a, H>; fn top_row<'a, H: HugrView>(context: &Ctx<'a, H>, n: Node) -> ValueRow { if let Some(sig) = context.hugr().signature(n) { @@ -331,33 +157,68 @@ fn singleton_in_row<'a, H: HugrView>( context: &Ctx<'a, H>, n: &Node, ip: &IncomingPort, - v: &PartialValue, + v: PV, ) -> ValueRow { let Some(sig) = context.hugr().signature(*n) else { panic!("dougrulz"); }; if sig.input_count() <= ip.index() { - panic!("bad port index: {} >= {}: {}", ip.index(), sig.input_count(), context.hugr().get_optype(*n).description()); - } - ValueRow::singleton_from_row( - &context.hugr().signature(*n).unwrap().input, - ip.index(), - v.clone(), - ) + panic!( + "bad port index: {} >= {}: {}", + ip.index(), + sig.input_count(), + context.hugr().get_optype(*n).description() + ); + } + ValueRow::singleton_from_row(&context.hugr().signature(*n).unwrap().input, ip.index(), v) +} + +fn partial_value_from_load_constant<'a, H: HugrView>( + context: &context::DataflowContext<'a, H>, + node: Node, +) -> PV { + let load_op = context.hugr().get_optype(node).as_load_constant().unwrap(); + let const_node = context + .hugr() + .single_linked_output(node, load_op.constant_port()) + .unwrap() + .0; + let const_op = context.hugr().get_optype(const_node).as_const().unwrap(); + context.value_handle(const_node, const_op.value()).into() } -#[derive(Debug,Clone,Copy,PartialEq,Eq,Hash)] -enum IO { Input, Output } +fn partial_value_tuple_from_value_row(r: ValueRow) -> PV { + PartialValue::variant(0, r.into_iter().map(|x| x.0)).into() +} -fn value_inputs<'a,H: HugrView>(context: &Ctx<'a, H>, n: Node) -> impl Iterator + 'a { +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum IO { + Input, + Output, +} + +fn value_inputs<'a, H: HugrView>( + context: &Ctx<'a, H>, + n: Node, +) -> impl Iterator + 'a { context.hugr().in_value_types(n).map(|x| x.0) } -fn value_outputs<'a,H: HugrView>(context: &Ctx<'a, H>, n: Node) -> impl Iterator + 'a { +fn value_outputs<'a, H: HugrView>( + context: &Ctx<'a, H>, + n: Node, +) -> impl Iterator + 'a { context.hugr().out_value_types(n).map(|x| x.0) } -fn tail_loop_worker<'b, 'a,H: HugrView>(context: &Ctx<'a, H>, n: Node, output_p: IncomingPort, control_variant: usize, v: &'b PartialValue) -> impl Iterator + 'b { +// todo this should work for dataflowblocks too +fn tail_loop_worker<'b, 'a, H: HugrView>( + context: &Ctx<'a, H>, + n: Node, + output_p: IncomingPort, + control_variant: usize, + v: &'b PV, +) -> impl Iterator + 'b { let tail_loop_op = context.get_optype(n).as_tail_loop().unwrap(); let num_variant_vals = if control_variant == 0 { tail_loop_op.just_inputs.len() @@ -365,9 +226,15 @@ fn tail_loop_worker<'b, 'a,H: HugrView>(context: &Ctx<'a, H>, n: Node, output_p: tail_loop_op.just_outputs.len() }; if output_p.index() == 0 { - Either::Left((0..num_variant_vals).map(move |i| (i.into(), v.variant_field_value(control_variant, i)))) + Either::Left( + (0..num_variant_vals) + .map(move |i| (i.into(), v.variant_field_value(control_variant, i))), + ) } else { - Either::Right(std::iter::once(((num_variant_vals + output_p.index()).into(), v.clone()))) + Either::Right(std::iter::once(( + (num_variant_vals + output_p.index()).into(), + v.clone(), + ))) } } @@ -377,9 +244,9 @@ ascent::ascent! { relation node(ArcCtx<'a, H>, Node); relation in_wire(ArcCtx<'a,H>, Node, IncomingPort); relation out_wire(ArcCtx<'a,H>, Node, OutgoingPort); - lattice out_wire_value(ArcCtx<'a,H>, Node, OutgoingPort, PartialValue); + lattice out_wire_value(ArcCtx<'a,H>, Node, OutgoingPort, PV); lattice node_in_value_row(ArcCtx<'a,H>, Node, ValueRow); - lattice in_wire_value(ArcCtx<'a,H>, Node, IncomingPort, PartialValue); + lattice in_wire_value(ArcCtx<'a,H>, Node, IncomingPort, PV); node(c, n) <-- context(c), for n in c.nodes(); @@ -389,7 +256,7 @@ ascent::ascent! { // All out wire values are initialised to Top. If any value is Top after // running we can infer that execution never reaches that value. - out_wire_value(c, n,p, PartialValue::Top) <-- out_wire(c, n,p); + out_wire_value(c, n,p, PV::top()) <-- out_wire(c, n,p); in_wire_value(c, n, ip, v) <-- in_wire(c, n, ip), if let Some((m,op)) = c.single_linked_output(*n, *ip), @@ -397,20 +264,20 @@ ascent::ascent! { node_in_value_row(c, n, top_row(c, *n)) <-- node(c, n); - node_in_value_row(c, n, singleton_in_row(c, n, p, v)) <-- in_wire_value(c, n, p, v); + node_in_value_row(c, n, singleton_in_row(c, n, p, v.clone())) <-- in_wire_value(c, n, p, v); // LoadConstant relation load_constant_node(ArcCtx<'a, H>, Node); load_constant_node(c, n) <-- node(c, n), if c.get_optype(*n).is_load_constant(); - out_wire_value(c, n, 0.into(), PartialValue::from_load_constant(c, *n)) <-- + out_wire_value(c, n, 0.into(), partial_value_from_load_constant(c, *n)) <-- load_constant_node(c, n); // MakeTuple relation make_tuple_node(ArcCtx<'a,H>, Node); make_tuple_node(c, n) <-- node(c, n), if c.get_optype(*n).is_make_tuple(); - out_wire_value(c, n, 0.into(), PartialValue::tuple_from_value_row(vs)) <-- + out_wire_value(c, n, 0.into(), partial_value_tuple_from_value_row(vs.clone())) <-- make_tuple_node(c, n), node_in_value_row(c, n, vs); // UnpackTuple @@ -464,39 +331,54 @@ impl<'a, H: HugrView> Dataflow<'a, H> { Self::default() } - pub fn run_hugr(&mut self, hugr: &'a H) -> ArcCtx<'a,H> { + pub fn run_hugr(&mut self, hugr: &'a H) -> ArcCtx<'a, H> { let context = context::DataflowContext::new(hugr); self.context.push((context.clone(),)); self.run(); context } - pub fn read_out_wire_partial_value(&self, context: &Ctx<'a,H>, w: Wire) -> Option { - self.out_wire_value.iter().find_map(|(c,n,p,v)| (c.as_ref() == context && &w.node() == n && &w.source() == p).then_some(v.clone())) + pub fn read_out_wire_partial_value( + &self, + context: &Ctx<'a, H>, + w: Wire, + ) -> Option { + self.out_wire_value.iter().find_map(|(c, n, p, v)| { + (c.as_ref() == context && &w.node() == n && &w.source() == p).then(|| v.clone().0) + }) } - pub fn read_out_wire_value(&self, context: &Ctx<'a,H>, w: Wire) -> Option { + pub fn read_out_wire_value(&self, context: &Ctx<'a, H>, w: Wire) -> Option { // dbg!(&w); let pv = self.read_out_wire_partial_value(context, w)?; // dbg!(&pv); - let (_, typ) = context.hugr().out_value_types(w.node()).find(|(p,_)| *p == w.source()).unwrap(); + let (_, typ) = context + .hugr() + .out_value_types(w.node()) + .find(|(p, _)| *p == w.source()) + .unwrap(); pv.try_into_value(&typ).ok() } } #[cfg(test)] mod test { - use hugr_core::{builder::{DFGBuilder, Dataflow, HugrBuilder, SubContainer}, extension::{prelude::BOOL_T, EMPTY_REG}, ops::{UnpackTuple, Value}, type_row, types::{FunctionType, SumType}}; - - use crate::const_fold2::datalog::PartialValue; + use hugr_core::{ + builder::{DFGBuilder, Dataflow, HugrBuilder, SubContainer}, + extension::{prelude::BOOL_T, EMPTY_REG}, + ops::{UnpackTuple, Value}, + type_row, + types::{FunctionType, SumType}, + }; + use hugr_core::extension::PartialValue; #[test] fn test_make_tuple() { let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); let v1 = builder.add_load_value(Value::false_val()); let v2 = builder.add_load_value(Value::true_val()); - let v3 = builder.make_tuple([v1,v2]).unwrap(); + let v3 = builder.make_tuple([v1, v2]).unwrap(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = super::Dataflow::new(); @@ -511,8 +393,11 @@ mod test { let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); let v1 = builder.add_load_value(Value::false_val()); let v2 = builder.add_load_value(Value::true_val()); - let v3 = builder.make_tuple([v1,v2]).unwrap(); - let [o1,o2] = builder.add_dataflow_op(UnpackTuple::new(type_row![BOOL_T,BOOL_T]), [v3]).unwrap().outputs_arr(); + let v3 = builder.make_tuple([v1, v2]).unwrap(); + let [o1, o2] = builder + .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T, BOOL_T]), [v3]) + .unwrap() + .outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = super::Dataflow::new(); @@ -528,7 +413,10 @@ mod test { fn test_unpack_const() { let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); let v1 = builder.add_load_value(Value::tuple([Value::true_val()])); - let [o] = builder.add_dataflow_op(UnpackTuple::new(type_row![BOOL_T]), [v1]).unwrap().outputs_arr(); + let [o] = builder + .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T]), [v1]) + .unwrap() + .outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = super::Dataflow::new(); @@ -541,9 +429,18 @@ mod test { #[test] fn test_tail_loop_never_iterates() { let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); - let r_v = Value::unit_sum(3,6).unwrap(); - let r_w = builder.add_load_value(Value::sum(1, [r_v.clone()], SumType::new([type_row![], r_v.get_type().into()])).unwrap()); - let tlb = builder.tail_loop_builder([], [], vec![r_v.get_type()].into()).unwrap(); + let r_v = Value::unit_sum(3, 6).unwrap(); + let r_w = builder.add_load_value( + Value::sum( + 1, + [r_v.clone()], + SumType::new([type_row![], r_v.get_type().into()]), + ) + .unwrap(), + ); + let tlb = builder + .tail_loop_builder([], [], vec![r_v.get_type()].into()) + .unwrap(); let [tl_o] = tlb.finish_with_outputs(r_w, []).unwrap().outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); @@ -559,8 +456,11 @@ mod test { #[test] fn test_tail_loop_always_iterates() { let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); - let r_w = builder.add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); - let tlb = builder.tail_loop_builder([], [], vec![BOOL_T].into()).unwrap(); + let r_w = builder + .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); + let tlb = builder + .tail_loop_builder([], [], vec![BOOL_T].into()) + .unwrap(); let [tl_o] = tlb.finish_with_outputs(r_w, []).unwrap().outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index d633d1e25..461d84e2a 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -10,59 +10,9 @@ use ascent::Lattice; use either::Either; use hugr_core::ops::{OpTag, OpTrait, Value}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; +use hugr_core::extension::ValueHandle; -#[derive(Clone, Debug)] -pub struct ValueHandle(Vec, Node, Arc); -impl ValueHandle { - pub fn value(&self) -> &Value { - self.2.as_ref() - } - - pub fn tag(&self) -> usize { - match self.value() { - Value::Sum { tag, .. } => *tag, - Value::Tuple { .. } => 0, - _ => panic!("ValueHandle::tag called on non-Sum, non-Tuple value"), - } - } - - pub fn index(self: &ValueHandle, i: usize) -> ValueHandle { - let vs = match self.value() { - Value::Sum { values, .. } => values, - Value::Tuple { vs, .. } => vs, - _ => panic!("ValueHandle::index called on non-Sum, non-Tuple value"), - }; - assert!(i < vs.len()); - let v = vs[i].clone().into(); - let mut is = self.0.clone(); - is.push(i); - Self(is, self.1, v) - } -} - -impl PartialEq for ValueHandle { - fn eq(&self, other: &Self) -> bool { - (&self.0, self.1) == (&other.0, other.1) - } -} - -impl Eq for ValueHandle {} - -impl Hash for ValueHandle { - fn hash(&self, state: &mut I) { - self.0.hash(state); - self.1.hash(state); - } -} - -impl Deref for ValueHandle { - type Target = Value; - - fn deref(&self) -> &Self::Target { - self.value() - } -} #[derive(Clone)] pub struct ValueCache(HashMap>); @@ -74,7 +24,7 @@ impl ValueCache { fn get(&mut self, node: Node, value: &Value) -> ValueHandle { let v = self.0.entry(node).or_insert_with(|| value.clone().into()).clone(); - ValueHandle(vec![], node, v) + ValueHandle::new(node, v) } } From 31e19e143cde5ca4eec6cd47a1b2568a9067e840 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Fri, 7 Jun 2024 12:31:26 +0100 Subject: [PATCH 05/12] proptest1 on PartialValue --- hugr-core/src/extension.rs | 2 +- hugr-core/src/extension/const_fold.rs | 2 - .../src/extension/const_fold/partial_value.rs | 292 -------- hugr-core/src/lib.rs | 1 + hugr-core/src/ops/constant.rs | 40 +- hugr-core/src/partial_value.rs | 705 ++++++++++++++++++ hugr-passes/src/const_fold2/datalog.rs | 104 ++- .../src/const_fold2/datalog/context.rs | 23 +- 8 files changed, 838 insertions(+), 331 deletions(-) delete mode 100644 hugr-core/src/extension/const_fold/partial_value.rs create mode 100644 hugr-core/src/partial_value.rs diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 90ffba912..1ef9c50cd 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -36,7 +36,7 @@ mod const_fold; pub mod prelude; pub mod simple_op; pub mod validate; -pub use const_fold::{ConstFold, ConstFoldResult, Folder, PartialValue, ValueHandle, HashableHashMap}; +pub use const_fold::{ConstFold, ConstFoldResult, Folder}; pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; pub mod declarative; diff --git a/hugr-core/src/extension/const_fold.rs b/hugr-core/src/extension/const_fold.rs index 4485a05b0..a3aae93eb 100644 --- a/hugr-core/src/extension/const_fold.rs +++ b/hugr-core/src/extension/const_fold.rs @@ -10,8 +10,6 @@ use crate::OutgoingPort; use crate::ops; -mod partial_value; -pub use partial_value::{PartialValue, ValueHandle, HashableHashMap}; /// Output of constant folding an operation, None indicates folding was either /// not possible or unsuccessful. An empty vector indicates folding was /// successful and no values are output. diff --git a/hugr-core/src/extension/const_fold/partial_value.rs b/hugr-core/src/extension/const_fold/partial_value.rs deleted file mode 100644 index f7de384a4..000000000 --- a/hugr-core/src/extension/const_fold/partial_value.rs +++ /dev/null @@ -1,292 +0,0 @@ -use std::sync::Arc; -use std::hash::{Hash, Hasher}; -use std::ops::Deref; -use std::collections::HashMap; - -use itertools::{zip_eq, Itertools as _}; - -use crate::ops::{OpTag, OpTrait, Value}; -use crate::types::{Type, TypeEnum}; -use crate::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; - -#[derive(Clone, Debug)] -pub struct ValueHandle(Vec, Node, Arc); - -impl ValueHandle { - pub fn new(node: Node, value: Arc) -> Self { - Self(vec![], node, value) - } - - pub fn value(&self) -> &Value { - self.2.as_ref() - } - - pub fn tag(&self) -> usize { - match self.value() { - Value::Sum { tag, .. } => *tag, - Value::Tuple { .. } => 0, - _ => panic!("ValueHandle::tag called on non-Sum, non-Tuple value"), - } - } - - pub fn index(self: &ValueHandle, i: usize) -> ValueHandle { - let vs = match self.value() { - Value::Sum { values, .. } => values, - Value::Tuple { vs, .. } => vs, - _ => panic!("ValueHandle::index called on non-Sum, non-Tuple value"), - }; - assert!(i < vs.len()); - let v = vs[i].clone().into(); - let mut is = self.0.clone(); - is.push(i); - Self(is, self.1, v) - } -} - -impl PartialEq for ValueHandle { - fn eq(&self, other: &Self) -> bool { - (&self.0, self.1) == (&other.0, other.1) - } -} - -impl Eq for ValueHandle {} - -impl Hash for ValueHandle { - fn hash(&self, state: &mut I) { - self.0.hash(state); - self.1.hash(state); - } -} - -/// TODO this is dodgy -/// we do not hash or compare the value, just the key -/// this means two handles with different keys, but with the same value, will -/// not compare equal. -impl Deref for ValueHandle { - type Target = Value; - - fn deref(&self) -> &Self::Target { - self.value() - } -} - -/// TODO shouldn't be pub -#[derive(PartialEq, Clone, Eq, Debug)] -pub struct HashableHashMap(HashMap); - -impl Hash for HashableHashMap { - fn hash(&self, state: &mut H) { - self.0.keys().for_each(|k| k.hash(state)); - self.0.values().for_each(|v| v.hash(state)); - } -} - -#[derive(PartialEq, Clone, Eq, Hash, Debug)] -pub enum PartialValue { - Bottom, - Value(ValueHandle), - PartialSum(HashableHashMap>), - Top, -} - -impl From for PartialValue { - fn from(v: ValueHandle) -> Self { - match v.value() { - Value::Tuple { vs } => { - let vec = (0..vs.len()).map(|i| PartialValue::from(v.index(i)).into()).collect(); - Self::PartialSum(HashableHashMap([(0, vec)].into_iter().collect())) - } - Value::Sum { tag, values, .. } => { - let vec = (0..values.len()).map(|i| PartialValue::from(v.index(i)).into()).collect(); - Self::PartialSum(HashableHashMap([(*tag, vec)].into_iter().collect())) - } - _ => Self::Value(v) - } - } -} - -impl PartialValue { - const BOTTOM: Self = Self::Bottom; - const BOTTOM_REF: &'static Self = &Self::BOTTOM; - - fn initialised(&self) -> bool { - !self.is_top() - } - - fn is_top(&self) -> bool { self == &PartialValue::Top } - - - /// TODO docs - /// just delegate to variant_field_value - pub fn tuple_field_value(&self, idx: usize) -> Self { - self.variant_field_value(0,idx) - } - - /// TODO docs - pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { - match self { - Self::Bottom => Self::Bottom, - Self::PartialSum(HashableHashMap(hm)) => { - if let Some(row) = hm.get(&variant) { - assert!(row.len() > idx); - row[idx].clone() - } else { - // We must return top. if self were to gain this variant, we would return the element of that variant. - // We must ensure that the value return now is <= that future value - Self::Top - } - }, - Self::Value(v) if v.tag() == variant => { - Self::Value(v.index(idx)) - }, - _ => Self::Top - } - } - - pub fn try_into_value(self, typ: &Type) -> Result { - let r = match self { - Self::Value(v) => v.value().clone(), - Self::PartialSum(HashableHashMap(hm)) => { - let err = |hm| Err(Self::PartialSum(HashableHashMap(hm))); - let Ok((k,v)) = hm.iter().exactly_one() else { - return err(hm); - }; - let TypeEnum::Sum(st) = typ.as_type_enum() else { - return err(hm); - }; - let Some(r) = st.get_variant(*k) else { - return err(hm); - }; - if v.len() != r.len() { - return err(hm); - } - - let Ok(vs) = zip_eq(v.into_iter(), r.into_iter()).map(|(v,t)| v.clone().try_into_value(t)).collect::, _>>() else { - return err(hm); - }; - - Value::sum(*k, vs, st.clone()).map_err(|_| Self::PartialSum(HashableHashMap(hm)))? - }, - x => Err(x)? - - }; - assert_eq!(typ, &r.get_type()); - Ok(r) - } - - fn join_value_handle(&mut self, vh: ValueHandle) -> bool { - let mut new_self = self; - match &mut new_self { - Self::Bottom => { - false - }, - s@Self::Value(_) => { - let Self::Value(v) = *s else { unreachable!() }; - if v == &vh { - false - } else { - **s = Self::Bottom; - true - } - }, - s@Self::PartialSum(_) => { - match vh.into() { - Self::Value(_) => { - **s = Self::Bottom; - true - } - other => s.join_mut(other) - } - }, - s@Self::Top => { - **s = vh.into(); - true - } - } - } - - pub fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self - } - - pub fn join_mut(&mut self, other: Self) -> bool { - // println!("join {self:?}\n{:?}", &other); - let mut s = self; - let changed = match (&mut s, other) { - (Self::Bottom, _) => false, - (s, other@Self::Bottom) => { - **s = other; - true - }, - (_, Self::Top) => false, - (s@Self::Top, other) => { - **s = other; - true - } - (Self::Value(h1), Self::Value(h2)) if h1 == &h2 || h1.value() == h2.value() => { - false - } - (s@Self::PartialSum(_), Self::PartialSum(HashableHashMap(hm2))) => { - let mut changed = false; - let Self::PartialSum(HashableHashMap(hm1)) = *s else { unreachable!() }; - for (k, v) in hm2 { - if let Some(row) = hm1.get_mut(&k) { - for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { - changed |= lhs.join_mut(rhs); - } - } else { - hm1.insert(k, v); - changed = true; - } - } - changed - } - (s@Self::Value(_), other@Self::PartialSum(_)) => { - let mut old_self = other; - std::mem::swap(*s, &mut old_self); - let Self::Value(h) = old_self else { unreachable!() }; - s.join_value_handle(h) - } - (s@Self::PartialSum(_), Self::Value(h)) => { - s.join_value_handle(h) - } - (s,_) => { - **s = Self::Bottom; - false - } - }; - // if changed { - // println!("join new self: {:?}", s); - // } - changed - } - - pub fn meet(mut self, other: Self) -> Self { - self.meet_mut(other) ; - self - } - - pub fn meet_mut(&mut self, _other: Self) -> bool { - todo!() - } - - pub fn top() -> Self { - Self::Top - } - - pub fn bottom() -> Self { - Self::Bottom - } - - pub fn variant(tag: usize, values: impl IntoIterator) -> Self { - Self::PartialSum(HashableHashMap([(tag, values.into_iter().collect())].into_iter().collect())) - } -} - -impl PartialOrd for PartialValue { - fn partial_cmp(&self, other: &Self) -> Option { - // TODO we can do better - (self == other).then_some(std::cmp::Ordering::Equal) - } -} diff --git a/hugr-core/src/lib.rs b/hugr-core/src/lib.rs index 6bd2a262d..318bd02fc 100644 --- a/hugr-core/src/lib.rs +++ b/hugr-core/src/lib.rs @@ -14,6 +14,7 @@ pub mod extension; pub mod hugr; pub mod macros; pub mod ops; +pub mod partial_value; pub mod std_extensions; pub mod types; pub mod utils; diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index f72e95a6d..486c910e9 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -138,18 +138,42 @@ pub enum Value { impl PartialEq for Value { fn eq(&self, other: &Self) -> bool { - let cmp_tuple_sum = |tuple_values: &[Value], tag: usize, sum_values: &[Value], sum_type: &SumType| { - tag == 0 && sum_type.num_variants() == 1 && tuple_values == sum_values - }; + let cmp_tuple_sum = + |tuple_values: &[Value], tag: usize, sum_values: &[Value], sum_type: &SumType| { + tag == 0 && sum_type.num_variants() == 1 && tuple_values == sum_values + }; match (self, other) { (Self::Extension { e: e1 }, Self::Extension { e: e2 }) => e1 == e2, (Self::Function { hugr: h1 }, Self::Function { hugr: h2 }) => h1 == h2, (Self::Tuple { vs: v1 }, Self::Tuple { vs: v2 }) => v1 == v2, - (Self::Sum { tag: t1, values: v1, sum_type: s1 }, Self::Sum { tag: t2, values: v2, sum_type: s2 }) => { - t1 == t2 && v1 == v2 && s1 == s2 - } - (Self::Tuple { vs }, Self::Sum {tag, values, sum_type}) | - (Self::Sum {tag, values, sum_type}, Self::Tuple { vs }) => cmp_tuple_sum(vs, *tag, values, sum_type), + ( + Self::Sum { + tag: t1, + values: v1, + sum_type: s1, + }, + Self::Sum { + tag: t2, + values: v2, + sum_type: s2, + }, + ) => t1 == t2 && v1 == v2 && s1 == s2, + ( + Self::Tuple { vs }, + Self::Sum { + tag, + values, + sum_type, + }, + ) + | ( + Self::Sum { + tag, + values, + sum_type, + }, + Self::Tuple { vs }, + ) => cmp_tuple_sum(vs, *tag, values, sum_type), _ => false, } } diff --git a/hugr-core/src/partial_value.rs b/hugr-core/src/partial_value.rs new file mode 100644 index 000000000..b348eeed6 --- /dev/null +++ b/hugr-core/src/partial_value.rs @@ -0,0 +1,705 @@ +#![allow(missing_docs)] +use std::any::Any; +use std::collections::HashMap; +use std::hash::{DefaultHasher, Hash, Hasher}; +use std::ops::Deref; +use std::sync::Arc; + +use downcast_rs::Downcast; +use itertools::{zip_eq, Either, Itertools as _}; + +use crate::ops::{OpTag, OpTrait, Value}; +use crate::types::{Type, TypeEnum}; +use crate::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; + +pub trait ValueName: std::fmt::Debug + Downcast + Any { + fn hash(&self) -> u64; + fn eq(&self, other: &dyn ValueName) -> bool; +} + +#[derive(Clone, Debug)] +pub struct ValueKey(Vec, Either>); + +impl PartialEq for ValueKey { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + && match (&self.1, &other.1) { + (Either::Left(ref n1), Either::Left(ref n2)) => n1 == n2, + (Either::Right(ref v1), Either::Right(ref v2)) => v1.eq(v2.as_ref()), + _ => false, + } + } +} + +impl Eq for ValueKey {} + +impl Hash for ValueKey { + fn hash(&self, state: &mut H) { + self.0.hash(state); + match &self.1 { + Either::Left(n) => n.hash(state), + Either::Right(v) => state.write_u64(v.hash()), + } + } +} + +impl ValueName for String { + fn hash(&self) -> u64 { + let mut hasher = DefaultHasher::new(); + ::hash(self, &mut hasher); + hasher.finish() + } + + fn eq(&self, other: &dyn ValueName) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self == other + } else { + false + } + } +} + +impl From for ValueKey { + fn from(n: Node) -> Self { + Self(vec![], Either::Left(n)) + } +} + +#[derive(Clone, Debug)] +pub struct ValueHandle(ValueKey, Arc); + +impl ValueHandle { + pub fn new(key: ValueKey, value: Arc) -> Self { + Self(key, value) + } + + pub fn value(&self) -> &Value { + self.1.as_ref() + } + + pub fn tag(&self) -> usize { + match self.value() { + Value::Sum { tag, .. } => *tag, + Value::Tuple { .. } => 0, + _ => panic!("ValueHandle::tag called on non-Sum, non-Tuple value"), + } + } + + pub fn index(self: &ValueHandle, i: usize) -> ValueHandle { + let vs = match self.value() { + Value::Sum { values, .. } => values, + Value::Tuple { vs, .. } => vs, + _ => panic!("ValueHandle::index called on non-Sum, non-Tuple value"), + }; + assert!(i < vs.len()); + let v = vs[i].clone().into(); + let mut is = self.0 .0.clone(); + is.push(i); + Self(ValueKey(is, self.0 .1.clone()), v) + } +} + +impl PartialEq for ValueHandle { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for ValueHandle {} + +impl Hash for ValueHandle { + fn hash(&self, state: &mut I) { + self.0.hash(state); + } +} + +/// TODO this is perhaps dodgy +/// we do not hash or compare the value, just the key +/// this means two handles with different keys, but with the same value, will +/// not compare equal. +impl Deref for ValueHandle { + type Target = Value; + + fn deref(&self) -> &Self::Target { + self.value() + } +} + +/// TODO shouldn't be pub +#[derive(PartialEq, Clone, Eq)] +pub struct HashableHashMap(HashMap); + +impl std::fmt::Debug + for HashableHashMap +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl Hash for HashableHashMap { + fn hash(&self, state: &mut H) { + self.0.keys().for_each(|k| k.hash(state)); + self.0.values().for_each(|v| v.hash(state)); + } +} + +#[derive(PartialEq, Clone, Eq, Hash, Debug)] +pub enum PartialValue { + Bottom, + Value(ValueHandle), + PartialSum(HashableHashMap>), + Top, +} + +impl From for PartialValue { + fn from(v: ValueHandle) -> Self { + match v.value() { + Value::Tuple { vs } => { + let vec = (0..vs.len()) + .map(|i| PartialValue::from(v.index(i)).into()) + .collect(); + Self::PartialSum(HashableHashMap([(0, vec)].into_iter().collect())) + } + Value::Sum { tag, values, .. } => { + let vec = (0..values.len()) + .map(|i| PartialValue::from(v.index(i)).into()) + .collect(); + Self::PartialSum(HashableHashMap([(*tag, vec)].into_iter().collect())) + } + _ => Self::Value(v), + } + } +} + +impl PartialValue { + const BOTTOM: Self = Self::Bottom; + const BOTTOM_REF: &'static Self = &Self::BOTTOM; + + fn initialised(&self) -> bool { + !self.is_top() + } + + fn is_top(&self) -> bool { + self == &PartialValue::Top + } + + fn assert_invariants(&self) { + match self { + Self::PartialSum(HashableHashMap(hm)) => { + assert_ne!(hm.len(), 0); + for pv in hm.values().flat_map(|x| x.iter()) { + pv.assert_invariants(); + } + } + Self::Value(v) => { + assert!(matches!(v.clone().into(), Self::Value(_))) + } + _ => {} + } + } + + /// TODO docs + /// just delegate to variant_field_value + pub fn tuple_field_value(&self, idx: usize) -> Self { + self.variant_field_value(0, idx) + } + + /// TODO docs + pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { + match self { + Self::Bottom => Self::Bottom, + Self::PartialSum(HashableHashMap(hm)) => { + if let Some(row) = hm.get(&variant) { + assert!(row.len() > idx); + row[idx].clone() + } else { + // We must return top. if self were to gain this variant, we would return the element of that variant. + // We must ensure that the value return now is <= that future value + Self::Top + } + } + Self::Value(v) if v.tag() == variant => Self::Value(v.index(idx)), + _ => Self::Top, + } + } + + pub fn try_into_value(self, typ: &Type) -> Result { + let r = match self { + Self::Value(v) => v.value().clone(), + Self::PartialSum(HashableHashMap(hm)) => { + let err = |hm| Err(Self::PartialSum(HashableHashMap(hm))); + let Ok((k, v)) = hm.iter().exactly_one() else { + return err(hm); + }; + let TypeEnum::Sum(st) = typ.as_type_enum() else { + return err(hm); + }; + let Some(r) = st.get_variant(*k) else { + return err(hm); + }; + if v.len() != r.len() { + return err(hm); + } + + let Ok(vs) = zip_eq(v.into_iter(), r.into_iter()) + .map(|(v, t)| v.clone().try_into_value(t)) + .collect::, _>>() + else { + return err(hm); + }; + + Value::sum(*k, vs, st.clone()).map_err(|_| Self::PartialSum(HashableHashMap(hm)))? + } + x => Err(x)?, + }; + assert_eq!(typ, &r.get_type()); + Ok(r) + } + + fn join_mut_value_handle(&mut self, vh: ValueHandle) -> bool { + self.assert_invariants(); + let mut new_self = self; + match &mut new_self { + Self::Top => false, + new_self @ Self::Value(_) => { + let Self::Value(v) = *new_self else { + unreachable!() + }; + if v == &vh { + false + } else { + **new_self = Self::Top; + true + } + } + s @ Self::PartialSum(_) => match vh.into() { + Self::Value(_) => { + **s = Self::Top; + true + } + other => s.join_mut(other), + }, + new_self @ Self::Bottom => { + **new_self = vh.into(); + true + } + } + } + + fn meet_mut_value_handle(&mut self, vh: ValueHandle) -> bool { + self.assert_invariants(); + let mut new_self = self; + match &mut new_self { + Self::Bottom => false, + new_self @ Self::Value(_) => { + let Self::Value(v) = *new_self else { + unreachable!() + }; + if v == &vh { + false + } else { + **new_self = Self::Bottom; + true + } + } + new_self @ Self::PartialSum(_) => match vh.into() { + Self::Value(_) => { + **new_self = Self::Bottom; + true + } + other => new_self.join_mut(other), + }, + new_self @ Self::Top => { + **new_self = vh.into(); + true + } + } + } + + pub fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + pub fn join_mut(&mut self, other: Self) -> bool { + // println!("join {self:?}\n{:?}", &other); + let mut new_self = self; + let changed = match (&mut new_self, other) { + (Self::Top, _) => false, + (new_self, other @ Self::Top) => { + **new_self = other; + true + } + (_, Self::Bottom) => false, + (new_self @ Self::Bottom, other) => { + **new_self = other; + true + } + (Self::Value(h1), Self::Value(h2)) if h1 == &h2 || h1.value() == h2.value() => false, + (new_self @ Self::PartialSum(_), Self::PartialSum(HashableHashMap(hm2))) => { + let mut changed = false; + let Self::PartialSum(HashableHashMap(hm1)) = *new_self else { + unreachable!() + }; + for (k, v) in hm2 { + if let Some(row) = hm1.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.join_mut(rhs); + } + } else { + hm1.insert(k, v); + changed = true; + } + } + changed + } + (new_self @ Self::Value(_), other @ Self::PartialSum(_)) => { + let mut old_self = other; + std::mem::swap(*new_self, &mut old_self); + let Self::Value(h) = old_self else { + unreachable!() + }; + new_self.join_mut_value_handle(h) + } + (new_self @ Self::PartialSum(_), Self::Value(h)) => new_self.join_mut_value_handle(h), + (new_self, _) => { + **new_self = Self::Top; + false + } + }; + // if changed { + // println!("join new self: {:?}", s); + // } + changed + } + + pub fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + pub fn meet_mut(&mut self, other: Self) -> bool { + let mut new_self = self; + let changed = match (&mut new_self, other) { + (Self::Bottom, _) => false, + (new_self, other @ Self::Bottom) => { + **new_self = other; + true + } + (_, Self::Top) => false, + (new_self @ Self::Top, other) => { + **new_self = other; + true + } + (Self::Value(h1), Self::Value(h2)) if h1 == &h2 || h1.value() == h2.value() => false, + (new_self @ Self::PartialSum(_), Self::PartialSum(HashableHashMap(hm2))) => { + let mut changed = false; + let Self::PartialSum(HashableHashMap(hm1)) = *new_self else { + unreachable!() + }; + let mut keys_to_remove = vec![]; + for k in hm1.keys() { + if !hm2.contains_key(k) { + keys_to_remove.push(*k); + } + } + for (k, v) in hm2 { + if let Some(row) = hm1.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.meet_mut(rhs); + } + } else { + keys_to_remove.push(k); + } + } + for k in keys_to_remove { + hm1.remove(&k); + changed = true; + } + changed + } + (new_self @ Self::Value(_), other @ Self::PartialSum(_)) => { + let mut old_self = other; + std::mem::swap(*new_self, &mut old_self); + let Self::Value(h) = old_self else { + unreachable!() + }; + new_self.meet_mut_value_handle(h) + } + (s @ Self::PartialSum(_), Self::Value(h)) => s.meet_mut_value_handle(h), + (new_self, _) => { + **new_self = Self::Bottom; + false + } + }; + // if changed { + // println!("join new self: {:?}", s); + // } + changed + } + + pub fn top() -> Self { + Self::Top + } + + pub fn bottom() -> Self { + Self::Bottom + } + + pub fn variant(tag: usize, values: impl IntoIterator) -> Self { + Self::PartialSum(HashableHashMap( + [(tag, values.into_iter().collect())].into_iter().collect(), + )) + } + + pub fn unit() -> Self { + Self::variant(0, []) + } +} + +impl PartialOrd for PartialValue { + fn partial_cmp(&self, other: &Self) -> Option { + use std::cmp::Ordering; + match (self, other) { + (Self::Bottom, Self::Bottom) => Some(Ordering::Equal), + (Self::Top, Self::Top) => Some(Ordering::Equal), + (Self::Bottom, _) => Some(Ordering::Less), + (_, Self::Bottom) => Some(Ordering::Greater), + (Self::Top, _) => Some(Ordering::Greater), + (_, Self::Top) => Some(Ordering::Less), + (Self::Value(v1), Self::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), + (Self::PartialSum(HashableHashMap(hm1)), Self::PartialSum(HashableHashMap(hm2))) => { + let max_key = hm1.keys().chain(hm2.keys()).copied().max().unwrap(); + let (mut keys1, mut keys2) = (vec![0; max_key + 1], vec![0; max_key + 1]); + for k in hm1.keys() { + keys1[*k] = 1; + } + + for k in hm2.keys() { + keys2[*k] = 1; + } + + if let Some(ord) = keys1.partial_cmp(&keys2) { + if ord != Ordering::Equal { + return Some(ord); + } + } else { + return None; + } + for (k, lhs) in hm1 { + let Some(rhs) = hm2.get(&k) else { + unreachable!() + }; + match lhs.partial_cmp(rhs) { + Some(Ordering::Equal) => continue, + x => { + return x; + } + } + } + Some(Ordering::Equal) + } + _ => None, + } + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use itertools::Itertools as _; + use lazy_static::lazy_static; + use proptest::prelude::*; + + use super::{PartialValue, ValueHandle}; + impl Arbitrary for ValueHandle { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { + // prop_oneof![ + + // ] + todo!() + } + } + + #[derive(Debug, PartialEq, Eq, Clone)] + struct UnarySumType(usize, Vec>>); + + lazy_static! { + static ref UNARY_SUM_TYPE_LEAF: UnarySumType = UnarySumType::new([]); + } + + impl UnarySumType { + pub fn new(vs: impl IntoIterator>>) -> Self { + let vec = vs.into_iter().collect_vec(); + let depth: usize = vec + .iter() + .flat_map(|x| x.iter()) + .map(|x| x.0 + 1) + .max() + .unwrap_or(0); + Self(depth, vec.into()).into() + } + + fn is_leaf(&self) -> bool { + self.0 == 0 + } + + fn assert_invariants(&self) { + if self.is_leaf() { + assert!(self.1.iter().all(Vec::is_empty)); + } else { + for v in self.1.iter().flat_map(|x| x.iter()) { + assert!(v.0 < self.0); + v.assert_invariants() + } + } + } + + fn select(self) -> impl Strategy>)>> { + if self.is_leaf() { + Just(None).boxed() + } else { + any::() + .prop_map(move |i| { + let index = i.index(self.1.len()); + Some((index, self.1[index].clone())) + }) + .boxed() + } + } + } + + #[derive(Clone, PartialEq, Eq, Debug)] + struct UnarySumTypeParams { + depth: usize, + branch_width: usize, + } + + impl UnarySumTypeParams { + pub fn descend(mut self, d: usize) -> Self { + assert!(d < self.depth); + self.depth = d; + self + } + } + + impl Default for UnarySumTypeParams { + fn default() -> Self { + Self { + depth: 3, + branch_width: 3, + } + } + } + + impl Arbitrary for UnarySumType { + type Parameters = UnarySumTypeParams; + type Strategy = BoxedStrategy; + fn arbitrary_with( + params @ UnarySumTypeParams { + depth, + branch_width, + }: Self::Parameters, + ) -> Self::Strategy { + if depth == 0 { + Just(UNARY_SUM_TYPE_LEAF.clone()).boxed() + } else { + (0..depth) + .prop_flat_map(move |d| { + prop::collection::vec( + prop::collection::vec( + any_with::(params.clone().descend(d)).prop_map_into(), + 0..branch_width, + ), + 1..=branch_width, + ) + .prop_map(UnarySumType::new) + }) + .boxed() + } + } + } + + proptest! { + #[test] + fn unary_sum_type_valid(ust: UnarySumType) { + ust.assert_invariants(); + } + } + + fn any_partial_value_of_type(ust: UnarySumType) -> impl Strategy { + ust.select().prop_flat_map(|x| { + if let Some((index, usts)) = x { + let pvs = usts + .into_iter() + .map(|x| any_partial_value_of_type(Arc::::unwrap_or_clone(x))) + .collect_vec(); + pvs.prop_map(move |pvs| PartialValue::variant(index, pvs)) + .boxed() + } else { + Just(PartialValue::unit()).boxed() + } + }) + } + + fn any_partial_value_with( + params: ::Parameters, + ) -> impl Strategy { + any_with::(params).prop_flat_map(any_partial_value_of_type) + } + + fn any_partial_value() -> impl Strategy { + any_partial_value_with(Default::default()) + } + + fn any_partial_values() -> impl Strategy { + any::().prop_flat_map(|ust| { + TryInto::<[_; N]>::try_into( + (0..N) + .map(|_| any_partial_value_of_type(ust.clone())) + .collect_vec(), + ) + .unwrap() + }) + } + + proptest! { + // todo: ValidHandle is valid + // todo: ValidHandle eq is an equivalence relation + + // todo: PartialValue PartialOrd is transitive + // todo: PartialValue eq is an equivalence relation + #[test] + fn partial_value_valid(pv in any_partial_value()) { + pv.assert_invariants(); + } + + #[test] + fn bounded_lattice(v in any_partial_value()) { + prop_assert!(&v <= &PartialValue::Top); + prop_assert!(&v >= &PartialValue::Bottom); + } + + #[test] + fn lattice_changed(v1 in any_partial_value()) { + let mut subject = v1.clone(); + assert!(!subject.join_mut(v1.clone())); + assert!(!subject.meet_mut(v1.clone())); + } + + #[test] + fn lattice([v1,v2] in any_partial_values()) { + let meet = v1.clone().meet(v2.clone()); + prop_assert!(&meet <= &v1, "meet not less <=: {:#?}", &meet); + prop_assert!(&meet <= &v2, "meet not less <=: {:#?}", &meet); + + let join = v1.clone().join(v2.clone()); + prop_assert!(&join >= &v1, "join not >=: {:#?}", &join); + prop_assert!(&join >= &v2, "join not >=: {:#?}", &join); + } + } +} diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index b497d91ee..06d5ea253 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -1,15 +1,14 @@ -use ascent::lattice::BoundedLattice; +use ascent::lattice::{BoundedLattice, Dual, Lattice}; use delegate::delegate; use itertools::{zip_eq, Itertools}; use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use ascent::{ascent_run, Lattice}; use either::Either; -use hugr_core::extension::{HashableHashMap, PartialValue, ValueHandle}; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::{OpTag, OpTrait, Value}; +use hugr_core::partial_value::{PartialValue, ValueHandle}; use hugr_core::types::{FunctionType, SumType, Type, TypeEnum, TypeRow}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; @@ -76,12 +75,8 @@ impl BoundedLattice for PV { struct ValueRow(Vec); impl ValueRow { - // fn into_partial_value(self) -> PartialValue { - // todo!() - // } - fn new(len: usize) -> Self { - Self(vec![PV::top(); len]) + Self(vec![PV::bottom(); len]) } fn singleton(len: usize, idx: usize, v: PV) -> Self { @@ -113,8 +108,9 @@ impl ValueRow { } impl Lattice for ValueRow { - fn meet(self, _other: Self) -> Self { - todo!() + fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self } fn join(mut self, other: Self) -> Self { @@ -130,6 +126,15 @@ impl Lattice for ValueRow { } changed } + + fn meet_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.meet_mut(v2); + } + changed + } } impl IntoIterator for ValueRow { @@ -142,8 +147,8 @@ impl IntoIterator for ValueRow { } } -type ArcCtx<'a, H: HugrView> = Arc>; -type Ctx<'a, H: HugrView> = DataflowContext<'a, H>; +type Ctx<'a, H> = DataflowContext<'a, H>; +type ArcCtx<'a, H> = Arc>; fn top_row<'a, H: HugrView>(context: &Ctx<'a, H>, n: Node) -> ValueRow { if let Some(sig) = context.hugr().signature(n) { @@ -184,7 +189,9 @@ fn partial_value_from_load_constant<'a, H: HugrView>( .unwrap() .0; let const_op = context.hugr().get_optype(const_node).as_const().unwrap(); - context.value_handle(const_node, const_op.value()).into() + context + .node_value_handle(const_node, const_op.value()) + .into() } fn partial_value_tuple_from_value_row(r: ValueRow) -> PV { @@ -238,6 +245,67 @@ fn tail_loop_worker<'b, 'a, H: HugrView>( } } +#[derive(PartialEq, Eq, PartialOrd, Hash, Debug, Clone)] +pub enum TailLoopTermination { + SingleIteration, + Unknown, + NeverTerminates, +} + +impl Lattice for TailLoopTermination { + fn meet_mut(&mut self, other: Self) -> bool { + match (self, other) { + (Self::SingleIteration, _) => false, + (s, o @ Self::SingleIteration) => { + *s = o; + true + } + (Self::Unknown, _) => false, + (s, o @ Self::Unknown) => { + *s = o; + true + } + _ => false, + } + } + + fn join_mut(&mut self, other: Self) -> bool { + match (self, other) { + (Self::NeverTerminates, _) => false, + (s, o @ Self::NeverTerminates) => { + *s = o; + true + } + (Self::Unknown, _) => false, + (s, o @ Self::Unknown) => { + *s = o; + true + } + _ => false, + } + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } +} + +impl BoundedLattice for TailLoopTermination { + fn bottom() -> Self { + Self::NeverTerminates + } + + fn top() -> Self { + Self::SingleIteration + } +} + ascent::ascent! { struct Dataflow<'a, H: HugrView>; relation context(ArcCtx<'a, H>); @@ -254,9 +322,9 @@ ascent::ascent! { out_wire(c, n,p) <-- node(c, n), for p in value_outputs(c, *n); - // All out wire values are initialised to Top. If any value is Top after + // All out wire values are initialised to Bottom. If any value is Bottom after // running we can infer that execution never reaches that value. - out_wire_value(c, n,p, PV::top()) <-- out_wire(c, n,p); + out_wire_value(c, n,p, PV::bottom()) <-- out_wire(c, n,p); in_wire_value(c, n, ip, v) <-- in_wire(c, n, ip), if let Some((m,op)) = c.single_linked_output(*n, *ip), @@ -307,6 +375,9 @@ ascent::ascent! { tail_loop_io_node(c,tl,n, io) <-- tail_loop_node(c,tl), if let Some([i,o]) = c.get_io(*tl), for (n,io) in [(i,IO::Input), (o, IO::Output)]; + lattice tail_loop_termination(ArcCtx<'a,H>,Node,Dual); + tail_loop_termination(c,n,Dual::top()) <-- tail_loop_node(c,n); + // inputs of tail loop propagate to Input node of child region out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- @@ -323,7 +394,6 @@ ascent::ascent! { in_wire_value(c, o, output_p, output_v), for (p, v) in tail_loop_worker(c, *tl, *output_p, 1, output_v); - } impl<'a, H: HugrView> Dataflow<'a, H> { @@ -371,7 +441,7 @@ mod test { types::{FunctionType, SumType}, }; - use hugr_core::extension::PartialValue; + use hugr_core::partial_value::PartialValue; #[test] fn test_make_tuple() { diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index 461d84e2a..e01b7d515 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -9,26 +9,27 @@ use ascent::Lattice; use either::Either; use hugr_core::ops::{OpTag, OpTrait, Value}; +use hugr_core::partial_value::{ValueHandle, ValueKey}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; -use hugr_core::extension::ValueHandle; - - #[derive(Clone)] -pub struct ValueCache(HashMap>); +pub struct ValueCache(HashMap>); impl ValueCache { fn new() -> Self { Self(HashMap::new()) } - fn get(&mut self, node: Node, value: &Value) -> ValueHandle { - let v = self.0.entry(node).or_insert_with(|| value.clone().into()).clone(); - ValueHandle::new(node, v) + fn get(&mut self, key: ValueKey, value: &Value) -> ValueHandle { + let v = self + .0 + .entry(key.clone()) + .or_insert_with(|| value.clone().into()) + .clone(); + ValueHandle::new(key, v) } } - static mut CONTEXT_ID: AtomicUsize = AtomicUsize::new(0); fn next_context_id() -> usize { @@ -50,8 +51,8 @@ impl<'a, H> DataflowContext<'a, H> { }) } - pub fn value_handle(&self, node: Node, value: &Value) -> ValueHandle { - self.cache.borrow_mut().get(node, value) + pub fn node_value_handle(&self, node: Node, value: &Value) -> ValueHandle { + self.cache.borrow_mut().get(node.into(), value) } pub fn hugr(&self) -> &'a H { @@ -95,7 +96,7 @@ impl PartialOrd for DataflowContext<'_, H> { } } -impl<'a,H> Deref for DataflowContext<'a,H> { +impl<'a, H> Deref for DataflowContext<'a, H> { type Target = H; fn deref(&self) -> &Self::Target { From 5123943e4d72fcf10dd27c24587de893943f4ef4 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Fri, 7 Jun 2024 15:00:53 +0100 Subject: [PATCH 06/12] refactoring --- hugr-core/src/partial_value.rs | 593 ++++++++++-------- hugr-core/src/partial_value/value_handle.rs | 225 +++++++ .../std_extensions/arithmetic/int_types.rs | 2 +- hugr-passes/src/const_fold2/datalog.rs | 1 - .../src/const_fold2/datalog/context.rs | 6 +- 5 files changed, 545 insertions(+), 282 deletions(-) create mode 100644 hugr-core/src/partial_value/value_handle.rs diff --git a/hugr-core/src/partial_value.rs b/hugr-core/src/partial_value.rs index b348eeed6..ed8e9adac 100644 --- a/hugr-core/src/partial_value.rs +++ b/hugr-core/src/partial_value.rs @@ -1,177 +1,209 @@ #![allow(missing_docs)] -use std::any::Any; +use std::cmp::Ordering; use std::collections::HashMap; -use std::hash::{DefaultHasher, Hash, Hasher}; -use std::ops::Deref; -use std::sync::Arc; +use std::hash::{Hash, Hasher}; -use downcast_rs::Downcast; -use itertools::{zip_eq, Either, Itertools as _}; +use itertools::{zip_eq, Itertools as _}; -use crate::ops::{OpTag, OpTrait, Value}; +use crate::ops::Value; use crate::types::{Type, TypeEnum}; -use crate::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; -pub trait ValueName: std::fmt::Debug + Downcast + Any { - fn hash(&self) -> u64; - fn eq(&self, other: &dyn ValueName) -> bool; -} +mod value_handle; -#[derive(Clone, Debug)] -pub struct ValueKey(Vec, Either>); +pub use value_handle::{ValueKey, ValueHandle}; -impl PartialEq for ValueKey { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 - && match (&self.1, &other.1) { - (Either::Left(ref n1), Either::Left(ref n2)) => n1 == n2, - (Either::Right(ref v1), Either::Right(ref v2)) => v1.eq(v2.as_ref()), - _ => false, - } - } -} -impl Eq for ValueKey {} +/// TODO shouldn't be pub +#[derive(PartialEq, Clone, Eq)] +pub struct PartialSum(HashMap>); -impl Hash for ValueKey { - fn hash(&self, state: &mut H) { - self.0.hash(state); - match &self.1 { - Either::Left(n) => n.hash(state), - Either::Right(v) => state.write_u64(v.hash()), - } +impl PartialSum { + pub fn variant(tag: usize, values: impl IntoIterator) -> Self { + Self([(tag, values.into_iter().collect())].into_iter().collect()) } -} -impl ValueName for String { - fn hash(&self) -> u64 { - let mut hasher = DefaultHasher::new(); - ::hash(self, &mut hasher); - hasher.finish() + pub fn num_variants(&self) -> usize { + self.0.len() } - fn eq(&self, other: &dyn ValueName) -> bool { - if let Some(other) = other.as_any().downcast_ref::() { - self == other - } else { - false + fn assert_variants(&self) { + assert_ne!(self.num_variants(), 0); + for pv in self.0.values().flat_map(|x| x.iter()) { + pv.assert_invariants(); } } -} -impl From for ValueKey { - fn from(n: Node) -> Self { - Self(vec![], Either::Left(n)) + pub fn variant_field_value(&self, variant: usize, idx: usize) -> PartialValue { + if let Some(row) = self.0.get(&variant) { + assert!(row.len() > idx); + row[idx].clone() + } else { + // We must return top. if self were to gain this variant, we would return the element of that variant. + // We must ensure that the value return now is <= that future value + PartialValue::top() + } } -} -#[derive(Clone, Debug)] -pub struct ValueHandle(ValueKey, Arc); + pub fn try_into_value(self, typ: &Type) -> Result { + let Ok((k, v)) = self.0.iter().exactly_one() else { + Err(self)? + }; -impl ValueHandle { - pub fn new(key: ValueKey, value: Arc) -> Self { - Self(key, value) + let TypeEnum::Sum(st) = typ.as_type_enum() else { + Err(self)? + }; + let Some(r) = st.get_variant(*k) else { + Err(self)? + }; + if v.len() != r.len() { + return Err(self) + } + match zip_eq(v.into_iter(), r.into_iter()) + .map(|(v, t)| v.clone().try_into_value(t)) + .collect::,_>>() { + Ok(vs) => { + Value::sum(*k, vs, st.clone()).map_err(|_| self) + } + Err(_) => Err(self) + } } - pub fn value(&self) -> &Value { - self.1.as_ref() - } + // unsafe because we panic if any common rows have different lengths + fn join_mut_unsafe(&mut self, other: Self) -> bool { + let mut changed = false; - pub fn tag(&self) -> usize { - match self.value() { - Value::Sum { tag, .. } => *tag, - Value::Tuple { .. } => 0, - _ => panic!("ValueHandle::tag called on non-Sum, non-Tuple value"), + for (k, v) in other.0 { + if let Some(row) = self.0.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.join_mut(rhs); + } + } else { + self.0.insert(k, v); + changed = true; + } } + changed } - pub fn index(self: &ValueHandle, i: usize) -> ValueHandle { - let vs = match self.value() { - Value::Sum { values, .. } => values, - Value::Tuple { vs, .. } => vs, - _ => panic!("ValueHandle::index called on non-Sum, non-Tuple value"), - }; - assert!(i < vs.len()); - let v = vs[i].clone().into(); - let mut is = self.0 .0.clone(); - is.push(i); - Self(ValueKey(is, self.0 .1.clone()), v) - } -} - -impl PartialEq for ValueHandle { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 + // unsafe because we panic if any common rows have different lengths + fn meet_mut_unsafe(&mut self, other: Self) -> bool { + let mut changed = false; + let mut keys_to_remove = vec![]; + for k in self.0.keys() { + if !other.0.contains_key(k) { + keys_to_remove.push(*k); + } + } + for (k, v) in other.0 { + if let Some(row) = self.0.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.meet_mut(rhs); + } + } else { + keys_to_remove.push(k); + } + } + for k in keys_to_remove { + self.0.remove(&k); + changed = true; + } + changed } } -impl Eq for ValueHandle {} - -impl Hash for ValueHandle { - fn hash(&self, state: &mut I) { - self.0.hash(state); - } -} +impl PartialOrd for PartialSum { + fn partial_cmp(&self, other: &Self) -> Option { + let max_key = self.0.keys().chain(other.0.keys()).copied().max().unwrap(); + let (mut keys1, mut keys2) = (vec![0; max_key + 1], vec![0; max_key + 1]); + for k in self.0.keys() { + keys1[*k] = 1; + } -/// TODO this is perhaps dodgy -/// we do not hash or compare the value, just the key -/// this means two handles with different keys, but with the same value, will -/// not compare equal. -impl Deref for ValueHandle { - type Target = Value; + for k in other.0.keys() { + keys2[*k] = 1; + } - fn deref(&self) -> &Self::Target { - self.value() + if let Some(ord) = keys1.partial_cmp(&keys2) { + if ord != Ordering::Equal { + return Some(ord); + } + } else { + return None; + } + for (k, lhs) in &self.0 { + let Some(rhs) = other.0.get(&k) else { + unreachable!() + }; + match lhs.partial_cmp(rhs) { + Some(Ordering::Equal) => continue, + x => { + return x; + } + } + } + Some(Ordering::Equal) } } -/// TODO shouldn't be pub -#[derive(PartialEq, Clone, Eq)] -pub struct HashableHashMap(HashMap); - -impl std::fmt::Debug - for HashableHashMap -{ +impl std::fmt::Debug for PartialSum { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } } -impl Hash for HashableHashMap { +impl Hash for PartialSum { fn hash(&self, state: &mut H) { - self.0.keys().for_each(|k| k.hash(state)); - self.0.values().for_each(|v| v.hash(state)); + for (k, v) in &self.0 { + k.hash(state); + v.hash(state); + } } } -#[derive(PartialEq, Clone, Eq, Hash, Debug)] -pub enum PartialValue { - Bottom, - Value(ValueHandle), - PartialSum(HashableHashMap>), - Top, -} +impl TryFrom for PartialSum { + type Error = ValueHandle; -impl From for PartialValue { - fn from(v: ValueHandle) -> Self { - match v.value() { + fn try_from(value: ValueHandle) -> Result { + match value.value() { Value::Tuple { vs } => { let vec = (0..vs.len()) - .map(|i| PartialValue::from(v.index(i)).into()) + .map(|i| PartialValue::from(value.index(i)).into()) .collect(); - Self::PartialSum(HashableHashMap([(0, vec)].into_iter().collect())) + return Ok(Self([(0, vec)].into_iter().collect())); } Value::Sum { tag, values, .. } => { let vec = (0..values.len()) - .map(|i| PartialValue::from(v.index(i)).into()) + .map(|i| PartialValue::from(value.index(i)).into()) .collect(); - Self::PartialSum(HashableHashMap([(*tag, vec)].into_iter().collect())) + return Ok(Self([(*tag, vec)].into_iter().collect())); } - _ => Self::Value(v), - } + _ => () + }; + Err(value) + } +} + +#[derive(PartialEq, Clone, Eq, Hash, Debug)] +pub enum PartialValue { + Bottom, + Value(ValueHandle), + PartialSum(PartialSum), + Top, +} + +impl From for PartialValue { + fn from(v: ValueHandle) -> Self { + TryInto::::try_into(v).map_or_else(Self::Value, Self::PartialSum) } } +impl From for PartialValue { + fn from(v: PartialSum) -> Self { + Self::PartialSum(v) + } +} + + impl PartialValue { const BOTTOM: Self = Self::Bottom; const BOTTOM_REF: &'static Self = &Self::BOTTOM; @@ -186,11 +218,8 @@ impl PartialValue { fn assert_invariants(&self) { match self { - Self::PartialSum(HashableHashMap(hm)) => { - assert_ne!(hm.len(), 0); - for pv in hm.values().flat_map(|x| x.iter()) { - pv.assert_invariants(); - } + Self::PartialSum(ps) => { + ps.assert_variants(); } Self::Value(v) => { assert!(matches!(v.clone().into(), Self::Value(_))) @@ -209,15 +238,8 @@ impl PartialValue { pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { match self { Self::Bottom => Self::Bottom, - Self::PartialSum(HashableHashMap(hm)) => { - if let Some(row) = hm.get(&variant) { - assert!(row.len() > idx); - row[idx].clone() - } else { - // We must return top. if self were to gain this variant, we would return the element of that variant. - // We must ensure that the value return now is <= that future value - Self::Top - } + Self::PartialSum(ps) => { + ps.variant_field_value(variant, idx) } Self::Value(v) if v.tag() == variant => Self::Value(v.index(idx)), _ => Self::Top, @@ -226,33 +248,10 @@ impl PartialValue { pub fn try_into_value(self, typ: &Type) -> Result { let r = match self { - Self::Value(v) => v.value().clone(), - Self::PartialSum(HashableHashMap(hm)) => { - let err = |hm| Err(Self::PartialSum(HashableHashMap(hm))); - let Ok((k, v)) = hm.iter().exactly_one() else { - return err(hm); - }; - let TypeEnum::Sum(st) = typ.as_type_enum() else { - return err(hm); - }; - let Some(r) = st.get_variant(*k) else { - return err(hm); - }; - if v.len() != r.len() { - return err(hm); - } - - let Ok(vs) = zip_eq(v.into_iter(), r.into_iter()) - .map(|(v, t)| v.clone().try_into_value(t)) - .collect::, _>>() - else { - return err(hm); - }; - - Value::sum(*k, vs, st.clone()).map_err(|_| Self::PartialSum(HashableHashMap(hm)))? - } - x => Err(x)?, - }; + Self::Value(v) => Ok(v.value().clone()), + Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), + x => Err(x), + }?; assert_eq!(typ, &r.get_type()); Ok(r) } @@ -317,6 +316,14 @@ impl PartialValue { } } + fn value_handles_equal(&self, rhs: &ValueHandle) -> bool { + let Self::Value(lhs) = self else { unreachable!() }; + lhs == rhs + // The following is a good idea if ValueHandle gains an Eq + // instance and so does not do this check: + // || lhs.value() == rhs.value() + } + pub fn join(mut self, other: Self) -> Self { self.join_mut(other); self @@ -336,25 +343,22 @@ impl PartialValue { **new_self = other; true } - (Self::Value(h1), Self::Value(h2)) if h1 == &h2 || h1.value() == h2.value() => false, - (new_self @ Self::PartialSum(_), Self::PartialSum(HashableHashMap(hm2))) => { - let mut changed = false; - let Self::PartialSum(HashableHashMap(hm1)) = *new_self else { + (new_self @ Self::Value(_), Self::Value(h2)) => { + if new_self.value_handles_equal(&h2) { + false + } else { + **new_self = Self::Top; + true + } + } + (new_self @ Self::PartialSum(_), Self::PartialSum(ps2)) => { + let Self::PartialSum(ps1) = *new_self else { unreachable!() }; - for (k, v) in hm2 { - if let Some(row) = hm1.get_mut(&k) { - for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { - changed |= lhs.join_mut(rhs); - } - } else { - hm1.insert(k, v); - changed = true; - } - } - changed + + ps1.join_mut_unsafe(ps2) } - (new_self @ Self::Value(_), other @ Self::PartialSum(_)) => { + (new_self @ Self::Value(_), other) => { let mut old_self = other; std::mem::swap(*new_self, &mut old_self); let Self::Value(h) = old_self else { @@ -362,11 +366,11 @@ impl PartialValue { }; new_self.join_mut_value_handle(h) } - (new_self @ Self::PartialSum(_), Self::Value(h)) => new_self.join_mut_value_handle(h), - (new_self, _) => { - **new_self = Self::Top; - false - } + (new_self, Self::Value(h)) => new_self.join_mut_value_handle(h), + // (new_self, _) => { + // **new_self = Self::Top; + // false + // } }; // if changed { // println!("join new self: {:?}", s); @@ -392,32 +396,19 @@ impl PartialValue { **new_self = other; true } - (Self::Value(h1), Self::Value(h2)) if h1 == &h2 || h1.value() == h2.value() => false, - (new_self @ Self::PartialSum(_), Self::PartialSum(HashableHashMap(hm2))) => { - let mut changed = false; - let Self::PartialSum(HashableHashMap(hm1)) = *new_self else { + (new_self @ Self::Value(_), Self::Value(h2)) => { + if new_self.value_handles_equal(&h2) { + false + } else { + **new_self = Self::Bottom; + true + } + } + (new_self @ Self::PartialSum(_), Self::PartialSum(ps2)) => { + let Self::PartialSum(ps1) = *new_self else { unreachable!() }; - let mut keys_to_remove = vec![]; - for k in hm1.keys() { - if !hm2.contains_key(k) { - keys_to_remove.push(*k); - } - } - for (k, v) in hm2 { - if let Some(row) = hm1.get_mut(&k) { - for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { - changed |= lhs.meet_mut(rhs); - } - } else { - keys_to_remove.push(k); - } - } - for k in keys_to_remove { - hm1.remove(&k); - changed = true; - } - changed + ps1.meet_mut_unsafe(ps2) } (new_self @ Self::Value(_), other @ Self::PartialSum(_)) => { let mut old_self = other; @@ -428,10 +419,10 @@ impl PartialValue { new_self.meet_mut_value_handle(h) } (s @ Self::PartialSum(_), Self::Value(h)) => s.meet_mut_value_handle(h), - (new_self, _) => { - **new_self = Self::Bottom; - false - } + // (new_self, _) => { + // **new_self = Self::Bottom; + // false + // } }; // if changed { // println!("join new self: {:?}", s); @@ -448,9 +439,7 @@ impl PartialValue { } pub fn variant(tag: usize, values: impl IntoIterator) -> Self { - Self::PartialSum(HashableHashMap( - [(tag, values.into_iter().collect())].into_iter().collect(), - )) + PartialSum::variant(tag, values).into() } pub fn unit() -> Self { @@ -468,38 +457,8 @@ impl PartialOrd for PartialValue { (_, Self::Bottom) => Some(Ordering::Greater), (Self::Top, _) => Some(Ordering::Greater), (_, Self::Top) => Some(Ordering::Less), - (Self::Value(v1), Self::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), - (Self::PartialSum(HashableHashMap(hm1)), Self::PartialSum(HashableHashMap(hm2))) => { - let max_key = hm1.keys().chain(hm2.keys()).copied().max().unwrap(); - let (mut keys1, mut keys2) = (vec![0; max_key + 1], vec![0; max_key + 1]); - for k in hm1.keys() { - keys1[*k] = 1; - } - - for k in hm2.keys() { - keys2[*k] = 1; - } - - if let Some(ord) = keys1.partial_cmp(&keys2) { - if ord != Ordering::Equal { - return Some(ord); - } - } else { - return None; - } - for (k, lhs) in hm1 { - let Some(rhs) = hm2.get(&k) else { - unreachable!() - }; - match lhs.partial_cmp(rhs) { - Some(Ordering::Equal) => continue, - x => { - return x; - } - } - } - Some(Ordering::Equal) - } + (Self::Value(_), Self::Value(v2)) => self.value_handles_equal(v2).then_some(Ordering::Equal), + (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), _ => None, } } @@ -513,6 +472,8 @@ mod test { use lazy_static::lazy_static; use proptest::prelude::*; + use crate::{std_extensions::arithmetic::int_types::{self, INT_TYPES, LOG_WIDTH_BOUND}, types::{CustomType, Type, TypeEnum}}; + use super::{PartialValue, ValueHandle}; impl Arbitrary for ValueHandle { type Parameters = (); @@ -526,51 +487,120 @@ mod test { } #[derive(Debug, PartialEq, Eq, Clone)] - struct UnarySumType(usize, Vec>>); + enum TestSumLeafType { + Int(Type), + Unit, + } - lazy_static! { - static ref UNARY_SUM_TYPE_LEAF: UnarySumType = UnarySumType::new([]); + impl Arbitrary for TestSumLeafType { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_params: Self::Parameters,) -> Self::Strategy { + let int_strat = (0..LOG_WIDTH_BOUND).prop_map(|i| Self::Int(INT_TYPES[i as usize].clone())); + prop_oneof![ + Just(TestSumLeafType::Unit), + int_strat + ].boxed() + } } - impl UnarySumType { - pub fn new(vs: impl IntoIterator>>) -> Self { + #[derive(Debug, PartialEq, Eq, Clone)] + enum TestSumType { + Branch(usize, Vec>>), + Leaf(TestSumLeafType) + } + + impl TestSumType { + const UNIT: TestSumLeafType = TestSumLeafType::Unit; + + pub fn leaf(v: Type) -> Self { + TestSumType::Leaf(TestSumLeafType::Int(v)) + } + + pub fn branch(vs: impl IntoIterator>>) -> Self { let vec = vs.into_iter().collect_vec(); let depth: usize = vec .iter() .flat_map(|x| x.iter()) - .map(|x| x.0 + 1) + .map(|x| x.depth() + 1) .max() .unwrap_or(0); - Self(depth, vec.into()).into() + Self::Branch(depth, vec.into()).into() + } + + fn depth(&self) -> usize { + match self { + TestSumType::Branch(x, _) => *x, + TestSumType::Leaf(_) => 0, + } } fn is_leaf(&self) -> bool { - self.0 == 0 + self.depth() == 0 } fn assert_invariants(&self) { - if self.is_leaf() { - assert!(self.1.iter().all(Vec::is_empty)); - } else { - for v in self.1.iter().flat_map(|x| x.iter()) { - assert!(v.0 < self.0); - v.assert_invariants() + match self { + TestSumType::Branch(d, sop) => { + assert!(!sop.is_empty(), "No variants"); + for v in sop.iter().flat_map(|x| x.iter()) { + assert!(v.depth() < *d); + v.assert_invariants(); + } } + TestSumType::Leaf(TestSumLeafType::Int(t)) => { + if let TypeEnum::Extension(ct) = t.as_type_enum() { + assert_eq!("int", ct.name()); + assert_eq!(&int_types::EXTENSION_ID, ct.extension()); + } else { + panic!("Expected int type, got {:#?}", t); + } + }, + _ => () } } fn select(self) -> impl Strategy>)>> { - if self.is_leaf() { - Just(None).boxed() - } else { - any::() - .prop_map(move |i| { - let index = i.index(self.1.len()); - Some((index, self.1[index].clone())) - }) - .boxed() + match self { + TestSumType::Branch(_, sop) => { + any::() + .prop_map(move |i| { + let index = i.index(sop.len()); + Some((index, sop[index].clone())) + }) + .boxed() + } + TestSumType::Leaf(_) => Just(None).boxed() + } } + + // fn type_check(&self, pv: PartialValue) -> bool { + // match (self,pv) { + // (_, PartialValue::Bottom) | PartialValue::Top => true, + // (_, PartialValue::Value(_)) => todo!(), + // (TestSumType::Branch(_, sop), PartialValue::PartialSum(ps)) => { + // for (k,v) in ps.0 { + // if k >= sop.len() { + // return false + // } + + // } + // } + // (TestSumType::Branch(_, _), PartialValue::Top) => todo!(), + // (TestSumType::Leaf(_), PartialValue::Bottom) => todo!(), + // (TestSumType::Leaf(_), PartialValue::Value(_)) => todo!(), + // (TestSumType::Leaf(_), PartialValue::PartialSum(_)) => todo!(), + // (TestSumType::Leaf(_), PartialValue::Top) => todo!(), + // } + + // } + } + + impl From for TestSumType { + fn from(value: TestSumLeafType) -> Self { + Self::Leaf(value) + } } #[derive(Clone, PartialEq, Eq, Debug)] @@ -596,7 +626,7 @@ mod test { } } - impl Arbitrary for UnarySumType { + impl Arbitrary for TestSumType { type Parameters = UnarySumTypeParams; type Strategy = BoxedStrategy; fn arbitrary_with( @@ -606,7 +636,7 @@ mod test { }: Self::Parameters, ) -> Self::Strategy { if depth == 0 { - Just(UNARY_SUM_TYPE_LEAF.clone()).boxed() + any::().prop_map_into().boxed() } else { (0..depth) .prop_flat_map(move |d| { @@ -617,7 +647,7 @@ mod test { ), 1..=branch_width, ) - .prop_map(UnarySumType::new) + .prop_map(TestSumType::branch) }) .boxed() } @@ -626,17 +656,17 @@ mod test { proptest! { #[test] - fn unary_sum_type_valid(ust: UnarySumType) { + fn unary_sum_type_valid(ust: TestSumType) { ust.assert_invariants(); } } - fn any_partial_value_of_type(ust: UnarySumType) -> impl Strategy { + fn any_partial_value_of_type(ust: TestSumType) -> impl Strategy { ust.select().prop_flat_map(|x| { if let Some((index, usts)) = x { let pvs = usts .into_iter() - .map(|x| any_partial_value_of_type(Arc::::unwrap_or_clone(x))) + .map(|x| any_partial_value_of_type(Arc::::unwrap_or_clone(x))) .collect_vec(); pvs.prop_map(move |pvs| PartialValue::variant(index, pvs)) .boxed() @@ -647,9 +677,9 @@ mod test { } fn any_partial_value_with( - params: ::Parameters, + params: ::Parameters, ) -> impl Strategy { - any_with::(params).prop_flat_map(any_partial_value_of_type) + any_with::(params).prop_flat_map(any_partial_value_of_type) } fn any_partial_value() -> impl Strategy { @@ -657,7 +687,7 @@ mod test { } fn any_partial_values() -> impl Strategy { - any::().prop_flat_map(|ust| { + any::().prop_flat_map(|ust| { TryInto::<[_; N]>::try_into( (0..N) .map(|_| any_partial_value_of_type(ust.clone())) @@ -667,7 +697,18 @@ mod test { }) } + fn any_typed_partial_value() -> impl Strategy { + any::().prop_flat_map(|t| { + any_partial_value_of_type(t.clone()).prop_map(move |v| (t.clone(),v)) + }) + } + proptest! { + // #[test] + // fn partial_value_type((tst, pv) in any_typed_partial_value()) { + // prop_assert!(tst.type_check(pv)) + // } + // todo: ValidHandle is valid // todo: ValidHandle eq is an equivalence relation diff --git a/hugr-core/src/partial_value/value_handle.rs b/hugr-core/src/partial_value/value_handle.rs new file mode 100644 index 000000000..7587ed2d8 --- /dev/null +++ b/hugr-core/src/partial_value/value_handle.rs @@ -0,0 +1,225 @@ +use std::any::Any; +use std::ops::Deref; +use std::sync::Arc; +use std::hash::{DefaultHasher, Hash, Hasher}; + +use downcast_rs::Downcast; +use itertools::Either; + +use crate::ops::Value; +use crate::std_extensions::arithmetic::int_types::ConstInt; +use crate::Node; + +pub trait ValueName: std::fmt::Debug + Downcast + Any { + fn hash(&self) -> u64; + fn eq(&self, other: &dyn ValueName) -> bool; +} + +fn hash_hash(x: &impl Hash) -> u64 { + let mut hasher = DefaultHasher::new(); + x.hash(&mut hasher); + hasher.finish() +} + +fn value_name_eq(x: &T, other: &dyn ValueName) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + x == other + } else { + false + } +} + +impl ValueName for String { + fn hash(&self) -> u64 { + hash_hash(self) + } + + fn eq(&self, other: &dyn ValueName) -> bool { + value_name_eq(self, other) + } +} + +impl ValueName for ConstInt { + fn hash(&self) -> u64 { + hash_hash(self) + } + + fn eq(&self, other: &dyn ValueName) -> bool { + value_name_eq(self, other) + } +} + +#[derive(Clone, Debug)] +pub struct ValueKey(Vec, Either>); + +impl PartialEq for ValueKey { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + && match (&self.1, &other.1) { + (Either::Left(ref n1), Either::Left(ref n2)) => n1 == n2, + (Either::Right(ref v1), Either::Right(ref v2)) => v1.eq(v2.as_ref()), + _ => false, + } + } +} + +impl Eq for ValueKey {} + +impl Hash for ValueKey { + fn hash(&self, state: &mut H) { + self.0.hash(state); + match &self.1 { + Either::Left(n) => (0,n).hash(state), + Either::Right(v) => (1,v.hash()).hash(state), + } + } +} + +impl From for ValueKey { + fn from(n: Node) -> Self { + Self(vec![], Either::Left(n)) + } +} + +impl ValueKey { + pub fn new(k: impl ValueName) -> Self{ + Self(vec![], Either::Right(Arc::new(k))) + } + + pub fn index(self, i: usize) -> Self { + let mut is = self.0; + is.push(i); + Self(is, self.1) + } +} + +#[derive(Clone, Debug)] +pub struct ValueHandle(ValueKey, Arc); + +impl ValueHandle { + pub fn new(key: ValueKey, value: Arc) -> Self { + Self(key, value) + } + + pub fn value(&self) -> &Value { + self.1.as_ref() + } + + pub fn is_compound(&self) -> bool { + match self.value() { + Value::Sum { .. } | Value::Tuple { .. } => true, + _ => false, + } + } + + pub fn num_fields(&self) -> usize { + assert!(self.is_compound(), "ValueHandle::num_fields called on non-Sum, non-Tuple value: {:#?}", self); + match self.value() { + Value::Sum { values, .. } => values.len(), + | Value::Tuple { vs } => vs.len(), + _ => unreachable!(), + } + } + + pub fn tag(&self) -> usize { + assert!(self.is_compound(), "ValueHandle::tag called on non-Sum, non-Tuple value: {:#?}", self); + match self.value() { + Value::Sum { tag, .. } => *tag, + Value::Tuple { .. } => 0, + _ => unreachable!(), + } + } + + pub fn index(self: &ValueHandle, i: usize) -> ValueHandle { + assert!(i < self.num_fields(), "ValueHandle::index called with out-of-bounds index {}: {:#?}", i, &self); + let vs = match self.value() { + Value::Sum { values, .. } => values, + Value::Tuple { vs, .. } => vs, + _ => unreachable!() + }; + let v = vs[i].clone().into(); + Self(self.0.clone().index(i), v) + } +} + +impl PartialEq for ValueHandle { + fn eq(&self, other: &Self) -> bool { + // If the keys are equal, we return true since the values must have the + // same provenance, and so be equal. If the keys are different but the + // values are equal, we could return true if we didn't impl Eq, but + // since we do impl Eq, the Hash contract prohibits us from having equal + // values with different hashes. + let r = self.0 == other.0; + if r { + debug_assert_eq!(self.get_type(), other.get_type()); + } + r + } +} + +impl Eq for ValueHandle {} + +impl Hash for ValueHandle { + fn hash(&self, state: &mut I) { + self.0.hash(state); + } +} + +/// TODO this is perhaps dodgy +/// we do not hash or compare the value, just the key +/// this means two handles with different keys, but with the same value, will +/// not compare equal. +impl Deref for ValueHandle { + type Target = Value; + + fn deref(&self) -> &Self::Target { + self.value() + } +} + +#[cfg(test)] +mod test { + use crate::{ops::constant::CustomConst as _, types::SumType}; + + use super::*; + + #[test] + fn value_key_eq() { + let k1 = ValueKey::new("foo".to_string()); + let k2 = ValueKey::new("foo".to_string()); + let k3 = ValueKey::new("bar".to_string()); + + assert_eq!(k1, k2); + assert_ne!(k1, k3); + + let k4: ValueKey = From::::from(portgraph::NodeIndex::new(1).into()); + let k5 = From::::from(portgraph::NodeIndex::new(1).into()); + let k6 = From::::from(portgraph::NodeIndex::new(2).into()); + + assert_eq!(&k4,&k5); + assert_ne!(&k4,&k6); + + let k7 = k5.clone().index(3); + let k4 = k4.index(3); + + assert_eq!(&k4,&k7); + + let k5 = k5.index(2); + + assert_ne!(&k5,&k7); + } + + #[test] + fn value_handle_eq() { + let k_i = ConstInt::new_u(4,2).unwrap(); + let subject_val = Arc::new(Value::sum(0, [k_i.clone().into()], SumType::new([vec![k_i.get_type()], vec![]])).unwrap()); + + let k1 = ValueKey::new("foo".to_string()); + let v1 = ValueHandle::new(k1.clone(), subject_val.clone()); + let v2 = ValueHandle::new(k1.clone(), Value::extension(k_i).into()); + + // we do not compare the value, just the key + assert_ne!(v1.index(0), v2); + assert_eq!(v1.index(0).value(), v2.value()); + } +} diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index 5ff2abea3..fda85b4b6 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -80,7 +80,7 @@ const fn type_arg(log_width: u8) -> TypeArg { } /// An integer (either signed or unsigned) -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Debug, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize)] pub struct ConstInt { log_width: u8, // We always use a u64 for the value. The interpretation is: diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 06d5ea253..2c7c624a2 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -6,7 +6,6 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use either::Either; -use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::{OpTag, OpTrait, Value}; use hugr_core::partial_value::{PartialValue, ValueHandle}; use hugr_core::types::{FunctionType, SumType, Type, TypeEnum, TypeRow}; diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index e01b7d515..4d981d0db 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -5,12 +5,10 @@ use std::ops::Deref; use std::sync::atomic::AtomicUsize; use std::sync::Arc; -use ascent::Lattice; -use either::Either; -use hugr_core::ops::{OpTag, OpTrait, Value}; +use hugr_core::ops::Value; use hugr_core::partial_value::{ValueHandle, ValueKey}; -use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; +use hugr_core::Node; #[derive(Clone)] pub struct ValueCache(HashMap>); From 7ac9ec8417c7a3933db23a3db105d25e04a520b3 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Fri, 7 Jun 2024 15:49:40 +0100 Subject: [PATCH 07/12] wip --- hugr-core/src/partial_value.rs | 281 +----------------- .../std_extensions/arithmetic/int_types.rs | 2 +- 2 files changed, 3 insertions(+), 280 deletions(-) diff --git a/hugr-core/src/partial_value.rs b/hugr-core/src/partial_value.rs index ed8e9adac..aed009b78 100644 --- a/hugr-core/src/partial_value.rs +++ b/hugr-core/src/partial_value.rs @@ -18,6 +18,7 @@ pub use value_handle::{ValueKey, ValueHandle}; pub struct PartialSum(HashMap>); impl PartialSum { + pub fn unit() -> Self { Self::variant(0,[]) } pub fn variant(tag: usize, values: impl IntoIterator) -> Self { Self([(tag, values.into_iter().collect())].into_iter().collect()) } @@ -465,282 +466,4 @@ impl PartialOrd for PartialValue { } #[cfg(test)] -mod test { - use std::sync::Arc; - - use itertools::Itertools as _; - use lazy_static::lazy_static; - use proptest::prelude::*; - - use crate::{std_extensions::arithmetic::int_types::{self, INT_TYPES, LOG_WIDTH_BOUND}, types::{CustomType, Type, TypeEnum}}; - - use super::{PartialValue, ValueHandle}; - impl Arbitrary for ValueHandle { - type Parameters = (); - type Strategy = BoxedStrategy; - fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { - // prop_oneof![ - - // ] - todo!() - } - } - - #[derive(Debug, PartialEq, Eq, Clone)] - enum TestSumLeafType { - Int(Type), - Unit, - } - - impl Arbitrary for TestSumLeafType { - type Parameters = (); - type Strategy = BoxedStrategy; - fn arbitrary_with(_params: Self::Parameters,) -> Self::Strategy { - let int_strat = (0..LOG_WIDTH_BOUND).prop_map(|i| Self::Int(INT_TYPES[i as usize].clone())); - prop_oneof![ - Just(TestSumLeafType::Unit), - int_strat - ].boxed() - } - } - - #[derive(Debug, PartialEq, Eq, Clone)] - enum TestSumType { - Branch(usize, Vec>>), - Leaf(TestSumLeafType) - } - - impl TestSumType { - const UNIT: TestSumLeafType = TestSumLeafType::Unit; - - pub fn leaf(v: Type) -> Self { - TestSumType::Leaf(TestSumLeafType::Int(v)) - } - - pub fn branch(vs: impl IntoIterator>>) -> Self { - let vec = vs.into_iter().collect_vec(); - let depth: usize = vec - .iter() - .flat_map(|x| x.iter()) - .map(|x| x.depth() + 1) - .max() - .unwrap_or(0); - Self::Branch(depth, vec.into()).into() - } - - fn depth(&self) -> usize { - match self { - TestSumType::Branch(x, _) => *x, - TestSumType::Leaf(_) => 0, - } - } - - fn is_leaf(&self) -> bool { - self.depth() == 0 - } - - fn assert_invariants(&self) { - match self { - TestSumType::Branch(d, sop) => { - assert!(!sop.is_empty(), "No variants"); - for v in sop.iter().flat_map(|x| x.iter()) { - assert!(v.depth() < *d); - v.assert_invariants(); - } - } - TestSumType::Leaf(TestSumLeafType::Int(t)) => { - if let TypeEnum::Extension(ct) = t.as_type_enum() { - assert_eq!("int", ct.name()); - assert_eq!(&int_types::EXTENSION_ID, ct.extension()); - } else { - panic!("Expected int type, got {:#?}", t); - } - }, - _ => () - } - } - - fn select(self) -> impl Strategy>)>> { - match self { - TestSumType::Branch(_, sop) => { - any::() - .prop_map(move |i| { - let index = i.index(sop.len()); - Some((index, sop[index].clone())) - }) - .boxed() - } - TestSumType::Leaf(_) => Just(None).boxed() - - } - } - - // fn type_check(&self, pv: PartialValue) -> bool { - // match (self,pv) { - // (_, PartialValue::Bottom) | PartialValue::Top => true, - // (_, PartialValue::Value(_)) => todo!(), - // (TestSumType::Branch(_, sop), PartialValue::PartialSum(ps)) => { - // for (k,v) in ps.0 { - // if k >= sop.len() { - // return false - // } - - // } - // } - // (TestSumType::Branch(_, _), PartialValue::Top) => todo!(), - // (TestSumType::Leaf(_), PartialValue::Bottom) => todo!(), - // (TestSumType::Leaf(_), PartialValue::Value(_)) => todo!(), - // (TestSumType::Leaf(_), PartialValue::PartialSum(_)) => todo!(), - // (TestSumType::Leaf(_), PartialValue::Top) => todo!(), - // } - - // } - } - - impl From for TestSumType { - fn from(value: TestSumLeafType) -> Self { - Self::Leaf(value) - } - } - - #[derive(Clone, PartialEq, Eq, Debug)] - struct UnarySumTypeParams { - depth: usize, - branch_width: usize, - } - - impl UnarySumTypeParams { - pub fn descend(mut self, d: usize) -> Self { - assert!(d < self.depth); - self.depth = d; - self - } - } - - impl Default for UnarySumTypeParams { - fn default() -> Self { - Self { - depth: 3, - branch_width: 3, - } - } - } - - impl Arbitrary for TestSumType { - type Parameters = UnarySumTypeParams; - type Strategy = BoxedStrategy; - fn arbitrary_with( - params @ UnarySumTypeParams { - depth, - branch_width, - }: Self::Parameters, - ) -> Self::Strategy { - if depth == 0 { - any::().prop_map_into().boxed() - } else { - (0..depth) - .prop_flat_map(move |d| { - prop::collection::vec( - prop::collection::vec( - any_with::(params.clone().descend(d)).prop_map_into(), - 0..branch_width, - ), - 1..=branch_width, - ) - .prop_map(TestSumType::branch) - }) - .boxed() - } - } - } - - proptest! { - #[test] - fn unary_sum_type_valid(ust: TestSumType) { - ust.assert_invariants(); - } - } - - fn any_partial_value_of_type(ust: TestSumType) -> impl Strategy { - ust.select().prop_flat_map(|x| { - if let Some((index, usts)) = x { - let pvs = usts - .into_iter() - .map(|x| any_partial_value_of_type(Arc::::unwrap_or_clone(x))) - .collect_vec(); - pvs.prop_map(move |pvs| PartialValue::variant(index, pvs)) - .boxed() - } else { - Just(PartialValue::unit()).boxed() - } - }) - } - - fn any_partial_value_with( - params: ::Parameters, - ) -> impl Strategy { - any_with::(params).prop_flat_map(any_partial_value_of_type) - } - - fn any_partial_value() -> impl Strategy { - any_partial_value_with(Default::default()) - } - - fn any_partial_values() -> impl Strategy { - any::().prop_flat_map(|ust| { - TryInto::<[_; N]>::try_into( - (0..N) - .map(|_| any_partial_value_of_type(ust.clone())) - .collect_vec(), - ) - .unwrap() - }) - } - - fn any_typed_partial_value() -> impl Strategy { - any::().prop_flat_map(|t| { - any_partial_value_of_type(t.clone()).prop_map(move |v| (t.clone(),v)) - }) - } - - proptest! { - // #[test] - // fn partial_value_type((tst, pv) in any_typed_partial_value()) { - // prop_assert!(tst.type_check(pv)) - // } - - // todo: ValidHandle is valid - // todo: ValidHandle eq is an equivalence relation - - // todo: PartialValue PartialOrd is transitive - // todo: PartialValue eq is an equivalence relation - #[test] - fn partial_value_valid(pv in any_partial_value()) { - pv.assert_invariants(); - } - - #[test] - fn bounded_lattice(v in any_partial_value()) { - prop_assert!(&v <= &PartialValue::Top); - prop_assert!(&v >= &PartialValue::Bottom); - } - - #[test] - fn lattice_changed(v1 in any_partial_value()) { - let mut subject = v1.clone(); - assert!(!subject.join_mut(v1.clone())); - assert!(!subject.meet_mut(v1.clone())); - } - - #[test] - fn lattice([v1,v2] in any_partial_values()) { - let meet = v1.clone().meet(v2.clone()); - prop_assert!(&meet <= &v1, "meet not less <=: {:#?}", &meet); - prop_assert!(&meet <= &v2, "meet not less <=: {:#?}", &meet); - - let join = v1.clone().join(v2.clone()); - prop_assert!(&join >= &v1, "join not >=: {:#?}", &join); - prop_assert!(&join >= &v2, "join not >=: {:#?}", &join); - } - } -} +mod test; diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index fda85b4b6..d79a4b84f 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -63,7 +63,7 @@ pub const LOG_WIDTH_TYPE_PARAM: TypeParam = TypeParam::bounded_nat({ /// Get the log width of the specified type argument or error if the argument /// is invalid. -pub(super) fn get_log_width(arg: &TypeArg) -> Result { +pub(crate) fn get_log_width(arg: &TypeArg) -> Result { match arg { TypeArg::BoundedNat { n } if is_valid_log_width(*n as u8) => Ok(*n as u8), _ => Err(TypeArgError::TypeMismatch { From 0a5c7299c1584738b7a24d095c47312cb96e9d2c Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Sun, 9 Jun 2024 06:16:25 +0100 Subject: [PATCH 08/12] wip --- hugr-core/src/partial_value.rs | 74 +++--- hugr-core/src/partial_value/test.rs | 328 +++++++++++++++++++++++++ hugr-passes/src/const_fold2/datalog.rs | 207 +++++++--------- 3 files changed, 467 insertions(+), 142 deletions(-) create mode 100644 hugr-core/src/partial_value/test.rs diff --git a/hugr-core/src/partial_value.rs b/hugr-core/src/partial_value.rs index aed009b78..3ab65f8c2 100644 --- a/hugr-core/src/partial_value.rs +++ b/hugr-core/src/partial_value.rs @@ -39,9 +39,7 @@ impl PartialSum { assert!(row.len() > idx); row[idx].clone() } else { - // We must return top. if self were to gain this variant, we would return the element of that variant. - // We must ensure that the value return now is <= that future value - PartialValue::top() + PartialValue::bottom() } } @@ -110,6 +108,10 @@ impl PartialSum { } changed } + + pub fn supports_tag(&self, tag: usize) -> bool { + self.0.contains_key(&tag) + } } impl PartialOrd for PartialSum { @@ -206,16 +208,16 @@ impl From for PartialValue { impl PartialValue { - const BOTTOM: Self = Self::Bottom; - const BOTTOM_REF: &'static Self = &Self::BOTTOM; + // const BOTTOM: Self = Self::Bottom; + // const BOTTOM_REF: &'static Self = &Self::BOTTOM; - fn initialised(&self) -> bool { - !self.is_top() - } + // fn initialised(&self) -> bool { + // !self.is_top() + // } - fn is_top(&self) -> bool { - self == &PartialValue::Top - } + // fn is_top(&self) -> bool { + // self == &PartialValue::Top + // } fn assert_invariants(&self) { match self { @@ -229,23 +231,6 @@ impl PartialValue { } } - /// TODO docs - /// just delegate to variant_field_value - pub fn tuple_field_value(&self, idx: usize) -> Self { - self.variant_field_value(0, idx) - } - - /// TODO docs - pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { - match self { - Self::Bottom => Self::Bottom, - Self::PartialSum(ps) => { - ps.variant_field_value(variant, idx) - } - Self::Value(v) if v.tag() == variant => Self::Value(v.index(idx)), - _ => Self::Top, - } - } pub fn try_into_value(self, typ: &Type) -> Result { let r = match self { @@ -446,6 +431,39 @@ impl PartialValue { pub fn unit() -> Self { Self::variant(0, []) } + + pub fn supports_tag(&self, tag: usize) -> bool { + match self { + PartialValue::Bottom => false, + PartialValue::Value(v) => v.tag() == tag, // can never be a sum or tuple + PartialValue::PartialSum(ps) => ps.supports_tag(tag), + PartialValue::Top => true, + } + } + + /// TODO docs + /// just delegate to variant_field_value + pub fn tuple_field_value(&self, idx: usize) -> Self { + self.variant_field_value(0, idx) + } + + /// TODO docs + pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { + match self { + Self::Bottom => Self::Bottom, + Self::PartialSum(ps) => { + ps.variant_field_value(variant, idx) + } + Self::Value(v) => { + if v.tag() == variant { + Self::Value(v.index(idx)) + } else { + Self::Bottom + } + }, + Self::Top => Self::Top, + } + } } impl PartialOrd for PartialValue { diff --git a/hugr-core/src/partial_value/test.rs b/hugr-core/src/partial_value/test.rs new file mode 100644 index 000000000..427f7983b --- /dev/null +++ b/hugr-core/src/partial_value/test.rs @@ -0,0 +1,328 @@ +use std::sync::Arc; + +use itertools::{zip_eq, Either, Itertools as _}; +use lazy_static::lazy_static; +use proptest::prelude::*; + +use crate::{ + ops::Value, std_extensions::arithmetic::int_types::{self, get_log_width, ConstInt, INT_TYPES, LOG_WIDTH_BOUND}, types::{CustomType, Type, TypeEnum} +}; + +use super::{PartialSum, PartialValue, ValueHandle, ValueKey}; +impl Arbitrary for ValueHandle { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { + // prop_oneof![ + + // ] + todo!() + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +enum TestSumLeafType { + Int(Type), + Unit, +} + +impl TestSumLeafType { + fn assert_invariants(&self) { + match self { + Self::Int(t) => { + if let TypeEnum::Extension(ct) = t.as_type_enum() { + assert_eq!("int", ct.name()); + assert_eq!(&int_types::EXTENSION_ID, ct.extension()); + } else { + panic!("Expected int type, got {:#?}", t); + } + } + _ => () + } + } + + fn get_type(&self) -> Type { + match self { + Self::Int(t) => t.clone(), + Self::Unit => Type::UNIT, + } + } + + fn type_check(&self, ps: &PartialSum) -> bool { + match self { + Self::Int(_) => false, + Self::Unit => { + if let Ok((0, v)) = ps.0.iter().exactly_one() { + v.is_empty() + } else { + false + } + } + } + } + + fn partial_value_strategy(self) -> impl Strategy { + match self { + Self::Int(t) => { + let TypeEnum::Extension(ct) = t.as_type_enum() else { unreachable!() }; + let lw = get_log_width(&ct.args()[0]).unwrap(); + (0u64..(1 << (2u64.pow(lw as u32) - 1))).prop_map(move |x| { + let ki = ConstInt::new_u(lw, x).unwrap(); + ValueHandle::new(ValueKey::new(ki.clone()), Arc::new(ki.into())).into() + }).boxed() + }, + Self::Unit => Just(PartialSum::unit().into()).boxed(), + } + } +} + +impl Arbitrary for TestSumLeafType { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { + let int_strat = (0..LOG_WIDTH_BOUND).prop_map(|i| Self::Int(INT_TYPES[i as usize].clone())); + prop_oneof![Just(TestSumLeafType::Unit), int_strat].boxed() + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +enum TestSumType { + Branch(usize, Vec>>), + Leaf(TestSumLeafType), +} + +impl TestSumType { + const UNIT: TestSumLeafType = TestSumLeafType::Unit; + + fn leaf(v: Type) -> Self { + TestSumType::Leaf(TestSumLeafType::Int(v)) + } + + fn branch(vs: impl IntoIterator>>) -> Self { + let vec = vs.into_iter().collect_vec(); + let depth: usize = vec + .iter() + .flat_map(|x| x.iter()) + .map(|x| x.depth() + 1) + .max() + .unwrap_or(0); + Self::Branch(depth, vec.into()).into() + } + + fn depth(&self) -> usize { + match self { + TestSumType::Branch(x, _) => *x, + TestSumType::Leaf(_) => 0, + } + } + + fn is_leaf(&self) -> bool { + self.depth() == 0 + } + + fn assert_invariants(&self) { + match self { + TestSumType::Branch(d, sop) => { + assert!(!sop.is_empty(), "No variants"); + for v in sop.iter().flat_map(|x| x.iter()) { + assert!(v.depth() < *d); + v.assert_invariants(); + } + } + TestSumType::Leaf(l) => { + l.assert_invariants(); + } + _ => (), + } + } + + fn select(self) -> impl Strategy>)>> { + match self { + TestSumType::Branch(_, sop) => any::() + .prop_map(move |i| { + let index = i.index(sop.len()); + Either::Right((index, sop[index].clone())) + }) + .boxed(), + TestSumType::Leaf(l) => Just(Either::Left(l)).boxed(), + } + } + + fn get_type(&self) -> Type { + match self { + TestSumType::Branch(_, sop) => Type::new_sum( + sop.iter() + .map(|row| row.iter().map(|x| x.get_type()).collect_vec().into()), + ), + TestSumType::Leaf(l) => l.get_type(), + } + } + + fn type_check(&self, pv: &PartialValue) -> bool { + match (self, pv) { + (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, + (_, PartialValue::Value(v)) => self.get_type() == v.get_type(), + (TestSumType::Branch(_, sop), PartialValue::PartialSum(ps)) => { + for (k, v) in &ps.0 { + if *k >= sop.len() { + return false; + } + let prod = &sop[*k]; + if prod.len() != v.len() { + return false; + } + if !zip_eq(prod, v).all(|(lhs, rhs)| lhs.type_check(&rhs)) { + return false; + } + } + true + } + (Self::Leaf(l), PartialValue::PartialSum(ps)) => l.type_check(&ps), + } + } +} + +impl From for TestSumType { + fn from(value: TestSumLeafType) -> Self { + Self::Leaf(value) + } +} + +#[derive(Clone, PartialEq, Eq, Debug)] +struct UnarySumTypeParams { + depth: usize, + branch_width: usize, +} + +impl UnarySumTypeParams { + pub fn descend(mut self, d: usize) -> Self { + assert!(d < self.depth); + self.depth = d; + self + } +} + +impl Default for UnarySumTypeParams { + fn default() -> Self { + Self { + depth: 3, + branch_width: 3, + } + } +} + +impl Arbitrary for TestSumType { + type Parameters = UnarySumTypeParams; + type Strategy = BoxedStrategy; + fn arbitrary_with( + params @ UnarySumTypeParams { + depth, + branch_width, + }: Self::Parameters, + ) -> Self::Strategy { + if depth == 0 { + any::().prop_map_into().boxed() + } else { + (0..depth) + .prop_flat_map(move |d| { + prop::collection::vec( + prop::collection::vec( + any_with::(params.clone().descend(d)).prop_map_into(), + 0..branch_width, + ), + 1..=branch_width, + ) + .prop_map(TestSumType::branch) + }) + .boxed() + } + } +} + +proptest! { + #[test] + fn unary_sum_type_valid(ust: TestSumType) { + ust.assert_invariants(); + } +} + +fn any_partial_value_of_type(ust: TestSumType) -> impl Strategy { + ust.select().prop_flat_map(|x| match x { + Either::Left(l) => l.partial_value_strategy().boxed(), + Either::Right((index, usts)) => { + let pvs = usts + .into_iter() + .map(|x| any_partial_value_of_type(Arc::::unwrap_or_clone(x))) + .collect_vec(); + pvs.prop_map(move |pvs| PartialValue::variant(index, pvs)) + .boxed() + } + }) +} + +fn any_partial_value_with( + params: ::Parameters, +) -> impl Strategy { + any_with::(params).prop_flat_map(any_partial_value_of_type) +} + +fn any_partial_value() -> impl Strategy { + any_partial_value_with(Default::default()) +} + +fn any_partial_values() -> impl Strategy { + any::().prop_flat_map(|ust| { + TryInto::<[_; N]>::try_into( + (0..N) + .map(|_| any_partial_value_of_type(ust.clone())) + .collect_vec(), + ) + .unwrap() + }) +} + +fn any_typed_partial_value() -> impl Strategy { + any::() + .prop_flat_map(|t| any_partial_value_of_type(t.clone()).prop_map(move |v| (t.clone(), v))) +} + +proptest! { + #[test] + fn partial_value_type((tst, pv) in any_typed_partial_value()) { + prop_assert!(tst.type_check(&pv)) + } + + // todo: ValidHandle is valid + // todo: ValidHandle eq is an equivalence relation + + // todo: PartialValue PartialOrd is transitive + // todo: PartialValue eq is an equivalence relation + #[test] + fn partial_value_valid(pv in any_partial_value()) { + pv.assert_invariants(); + } + + #[test] + fn bounded_lattice(v in any_partial_value()) { + prop_assert!(&v <= &PartialValue::Top); + prop_assert!(&v >= &PartialValue::Bottom); + } + + #[test] + fn lattice_changed(v1 in any_partial_value()) { + let mut subject = v1.clone(); + assert!(!subject.join_mut(v1.clone())); + assert!(!subject.meet_mut(v1.clone())); + } + + #[test] + fn lattice([v1,v2] in any_partial_values()) { + let meet = v1.clone().meet(v2.clone()); + prop_assert!(&meet <= &v1, "meet not less <=: {:#?}", &meet); + prop_assert!(&meet <= &v2, "meet not less <=: {:#?}", &meet); + + let join = v1.clone().join(v2.clone()); + prop_assert!(&join >= &v1, "join not >=: {:#?}", &join); + prop_assert!(&join >= &v2, "join not >=: {:#?}", &join); + } +} diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 2c7c624a2..25c87b460 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -1,4 +1,4 @@ -use ascent::lattice::{BoundedLattice, Dual, Lattice}; +use ascent::lattice::{BoundedLattice, Dual, Lattice, ord_lattice::OrdLattice}; use delegate::delegate; use itertools::{zip_eq, Itertools}; use std::collections::HashMap; @@ -8,7 +8,7 @@ use std::sync::Arc; use either::Either; use hugr_core::ops::{OpTag, OpTrait, Value}; use hugr_core::partial_value::{PartialValue, ValueHandle}; -use hugr_core::types::{FunctionType, SumType, Type, TypeEnum, TypeRow}; +use hugr_core::types::{EdgeKind, FunctionType, SumType, Type, TypeEnum, TypeRow}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; mod context; @@ -34,6 +34,10 @@ impl PV { fn variant_field_value(&self, variant: usize, idx: usize) -> Self { self.0.variant_field_value(variant, idx).into() } + + fn supports_tag(&self, tag: usize) -> bool { + self.0.supports_tag(tag) + } } impl From for PV { @@ -89,7 +93,7 @@ impl ValueRow { Self::singleton(r.len(), idx, v) } - fn top_from_row(r: &TypeRow) -> Self { + fn bottom_from_row(r: &TypeRow) -> Self { Self::new(r.len()) } @@ -101,9 +105,9 @@ impl ValueRow { zip_eq(value_inputs(context, n), self.0.iter()) } - fn initialised(&self) -> bool { - self.0.iter().all(|x| x != &PV::top()) - } + // fn initialised(&self) -> bool { + // self.0.iter().all(|x| x != &PV::top()) + // } } impl Lattice for ValueRow { @@ -149,7 +153,7 @@ impl IntoIterator for ValueRow { type Ctx<'a, H> = DataflowContext<'a, H>; type ArcCtx<'a, H> = Arc>; -fn top_row<'a, H: HugrView>(context: &Ctx<'a, H>, n: Node) -> ValueRow { +fn bottom_row<'a, H: HugrView>(context: &Ctx<'a, H>, n: Node) -> ValueRow { if let Some(sig) = context.hugr().signature(n) { ValueRow::new(sig.input_count()) } else { @@ -224,84 +228,51 @@ fn tail_loop_worker<'b, 'a, H: HugrView>( output_p: IncomingPort, control_variant: usize, v: &'b PV, -) -> impl Iterator + 'b { +) -> impl Iterator + 'b where 'a: 'b { let tail_loop_op = context.get_optype(n).as_tail_loop().unwrap(); let num_variant_vals = if control_variant == 0 { tail_loop_op.just_inputs.len() } else { tail_loop_op.just_outputs.len() }; + let hugr = context.hugr(); if output_p.index() == 0 { Either::Left( (0..num_variant_vals) .map(move |i| (i.into(), v.variant_field_value(control_variant, i))), ) } else { - Either::Right(std::iter::once(( - (num_variant_vals + output_p.index()).into(), - v.clone(), - ))) - } + let v = if v.supports_tag(control_variant) { + v.clone() + } else { + PV::bottom() + }; + Either::Right(std::iter::once(((num_variant_vals + output_p.index() - 1).into(), v))) + }.inspect(move |x| assert!(matches!(hugr.get_optype(n).port_kind(x.0), Some(EdgeKind::Value(_))))) } -#[derive(PartialEq, Eq, PartialOrd, Hash, Debug, Clone)] +#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Clone)] pub enum TailLoopTermination { - SingleIteration, - Unknown, NeverTerminates, + SingleIteration, + Terminates, } -impl Lattice for TailLoopTermination { - fn meet_mut(&mut self, other: Self) -> bool { - match (self, other) { - (Self::SingleIteration, _) => false, - (s, o @ Self::SingleIteration) => { - *s = o; - true - } - (Self::Unknown, _) => false, - (s, o @ Self::Unknown) => { - *s = o; - true - } - _ => false, - } - } - - fn join_mut(&mut self, other: Self) -> bool { - match (self, other) { - (Self::NeverTerminates, _) => false, - (s, o @ Self::NeverTerminates) => { - *s = o; - true - } - (Self::Unknown, _) => false, - (s, o @ Self::Unknown) => { - *s = o; - true - } - _ => false, +impl TailLoopTermination { + fn from_control_value(v: &PV) -> Self { + if v.supports_tag(1) && !v.supports_tag(0) { + Self::SingleIteration + } else if v.supports_tag(1) { + Self::Terminates + } else { + Self::NeverTerminates } } - - fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self - } - - fn meet(mut self, other: Self) -> Self { - self.meet_mut(other); - self - } } -impl BoundedLattice for TailLoopTermination { - fn bottom() -> Self { - Self::NeverTerminates - } - - fn top() -> Self { - Self::SingleIteration +impl From for OrdLattice { + fn from(value: TailLoopTermination) -> Self { + Self(value) } } @@ -330,7 +301,7 @@ ascent::ascent! { out_wire_value(c, m, op, v); - node_in_value_row(c, n, top_row(c, *n)) <-- node(c, n); + node_in_value_row(c, n, bottom_row(c, *n)) <-- node(c, n); node_in_value_row(c, n, singleton_in_row(c, n, p, v.clone())) <-- in_wire_value(c, n, p, v); // LoadConstant @@ -374,13 +345,13 @@ ascent::ascent! { tail_loop_io_node(c,tl,n, io) <-- tail_loop_node(c,tl), if let Some([i,o]) = c.get_io(*tl), for (n,io) in [(i,IO::Input), (o, IO::Output)]; - lattice tail_loop_termination(ArcCtx<'a,H>,Node,Dual); - tail_loop_termination(c,n,Dual::top()) <-- tail_loop_node(c,n); // inputs of tail loop propagate to Input node of child region out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- tail_loop_io_node(c,tl,i, IO::Input), in_wire_value(c, tl, p, v); + + // Output node of child region propagate to Input node of child region out_wire_value(c, i, input_p, v) <-- tail_loop_io_node(c,tl,i, IO::Input), @@ -393,6 +364,11 @@ ascent::ascent! { in_wire_value(c, o, output_p, output_v), for (p, v) in tail_loop_worker(c, *tl, *output_p, 1, output_v); + lattice tail_loop_termination(ArcCtx<'a,H>,Node,OrdLattice); + tail_loop_termination(c,tl,TailLoopTermination::NeverTerminates.into()) <-- tail_loop_node(c,tl); + tail_loop_termination(c,tl,TailLoopTermination::from_control_value(v).into()) <-- tail_loop_node(c,tl), + tail_loop_io_node(c,tl,o, IO::Output), + in_wire_value(c, o, Into::::into(0usize), v); } impl<'a, H: HugrView> Dataflow<'a, H> { @@ -428,19 +404,23 @@ impl<'a, H: HugrView> Dataflow<'a, H> { .unwrap(); pv.try_into_value(&typ).ok() } + + pub fn tail_loop_terminates(&self, context: &Ctx<'a, H>, node: Node) -> TailLoopTermination { + assert!(context.get_optype(node).is_tail_loop()); + self.tail_loop_termination.iter().find_map(|(c,n,v)| (c.as_ref() == context && n == &node).then_some(v.0.clone())).unwrap() + } } #[cfg(test)] mod test { use hugr_core::{ - builder::{DFGBuilder, Dataflow, HugrBuilder, SubContainer}, - extension::{prelude::BOOL_T, EMPTY_REG}, - ops::{UnpackTuple, Value}, - type_row, - types::{FunctionType, SumType}, + builder::{Container, DFGBuilder, Dataflow, HugrBuilder, SubContainer}, extension::{prelude::BOOL_T, EMPTY_REG}, ops::{handle::NodeHandle, OpTrait, UnpackTuple, Value}, type_row, types::{FunctionType, SumType}, HugrView, OutgoingPort, Wire }; use hugr_core::partial_value::PartialValue; + use itertools::Itertools; + + use crate::const_fold2::datalog::TailLoopTermination; #[test] fn test_make_tuple() { @@ -510,7 +490,8 @@ mod test { let tlb = builder .tail_loop_builder([], [], vec![r_v.get_type()].into()) .unwrap(); - let [tl_o] = tlb.finish_with_outputs(r_w, []).unwrap().outputs_arr(); + let tail_loop = tlb.finish_with_outputs(r_w, []).unwrap(); + let [tl_o] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = super::Dataflow::new(); @@ -520,6 +501,7 @@ mod test { let o_r = machine.read_out_wire_value(&c, tl_o).unwrap(); assert_eq!(o_r, r_v); + assert_eq!(TailLoopTermination::SingleIteration, machine.tail_loop_terminates(&c, tail_loop.node())) } #[test] @@ -530,7 +512,8 @@ mod test { let tlb = builder .tail_loop_builder([], [], vec![BOOL_T].into()) .unwrap(); - let [tl_o] = tlb.finish_with_outputs(r_w, []).unwrap().outputs_arr(); + let tail_loop = tlb.finish_with_outputs(r_w, []).unwrap(); + let [tl_o] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = super::Dataflow::new(); @@ -539,54 +522,50 @@ mod test { // dbg!(&machine.out_wire_value); let o_r = machine.read_out_wire_partial_value(&c, tl_o).unwrap(); - assert_eq!(o_r, PartialValue::Top); + assert_eq!(o_r, PartialValue::Bottom); + assert_eq!(TailLoopTermination::NeverTerminates, machine.tail_loop_terminates(&c, tail_loop.node())) } -} - -// fn tc(hugr: &impl HugrView, node: Node) -> Vec<(Node, OutgoingPort, PartialValue)> { -// assert!(OpTag::DataflowParent.is_superset(hugr.get_optype(node).tag())); -// let d = DescendantsGraph::<'_, Node>::try_new(hugr, node).unwrap(); -// let mut cache = ValueCache::new(); - -// let singleton_in_row = |n: &Node, ip: &IncomingPort, v: &PartialValue| -> ValueRow { -// ValueRow::singleton_from_row(&hugr.signature(*n).unwrap().input, ip.index(), v.clone()) -// }; -// let top_row = |n: &Node| -> ValueRow { -// ValueRow::top_from_row(&hugr.signature(*n).unwrap().input) -// }; -// // ascent! { -// // relation node(Node) = d.nodes().map(|x| (x,)).collect_vec(); - -// // relation in_wire(Node, IncomingPort); -// // in_wire(n,p) <-- node(n), for p in d.node_inputs(*n); - -// // relation out_wire(Node, OutgoingPort); -// // out_wire(n,p) <-- node(n), for p in d.node_outputs(*n); + #[test] + fn test_tail_loop_iterates_twice() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + // let var_type = Type::new_sum([type_row![BOOL_T,BOOL_T], type_row![BOOL_T,BOOL_T]]); -// // lattice out_wire_value(Node, OutgoingPort, PartialValue); -// // out_wire_value(n,p, PartialValue::Top) <-- out_wire(n,p); + let true_w = builder.add_load_value(Value::true_val()); + let false_w = builder.add_load_value(Value::false_val()); -// // lattice node_in_value_row(Node, ValueRow); -// // node_in_value_row(n, top_row(n)) <-- node(n); -// // node_in_value_row(n, singleton_in_row(n,ip,v)) <-- in_wire(n, ip), -// // if let Some((m,op)) = hugr.single_linked_output(*n, *ip), -// // out_wire_value(m, op, v); + // let r_w = builder + // .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); + let tlb = builder + .tail_loop_builder([], [(BOOL_T,false_w), (BOOL_T,true_w)], vec![].into()) + .unwrap(); + assert_eq!(tlb.loop_signature().unwrap().dataflow_signature().unwrap(), FunctionType::new_endo(type_row![BOOL_T,BOOL_T])); + let [in_w1,in_w2] = tlb.input_wires_arr(); + let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap(); -// // lattice in_wire_value(Node, IncomingPort, PartialValue); -// // in_wire_value(n,p,v) <-- node_in_value_row(n, vr), for (p,v) in vr.iter(hugr,*n); + // let optype = builder.hugr().get_optype(tail_loop.node()); + // for p in builder.hugr().node_outputs(tail_loop.node()) { + // use hugr_core::ops::OpType; + // println!("{:?}, {:?}", p, optype.port_kind(p)); -// // relation load_constant_node(Node); -// // load_constant_node(n) <-- node(n), if hugr.get_optype(*n).is_load_constant(); + // } -// // out_wire_value(n, 0.into(), PartialValue::from_load_constant(&mut cache, hugr, *n)) <-- -// // load_constant_node(n); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + // TODO once we can do conditionals put these wires inside `just_outputs` and + // we should be able to propagate their values + // let [o_w1, o_w2, _] = tail_loop.outputs_arr(); -// // relation make_tuple_node(Node); -// // make_tuple_node(n) <-- node(n), if hugr.get_optype(*n).is_make_tuple(); + let mut machine = super::Dataflow::new(); + let c = machine.run_hugr(&hugr); + dbg!(&machine.tail_loop_io_node); + dbg!(&machine.out_wire_value); -// // out_wire_value(n,0.into(), PartialValue::tuple_from_value_row(vs)) <-- -// // make_tuple_node(n), node_in_value_row(n, vs); + // TODO these hould be the propagated values + // let o_r1 = machine.read_out_wire_value(&c, o_w1).unwrap(); + // assert_eq!(o_r1, Value::false_val()); + // let o_r2 = machine.read_out_wire_value(&c, o_w2).unwrap(); + // assert_eq!(o_r2, Value::true_val()); + assert_eq!(TailLoopTermination::Terminates, machine.tail_loop_terminates(&c, tail_loop.node())) + } +} -// // }.out_wire_value -// } From 6aca4ced24da3c8c5fb57fd1576d72d33d39ac90 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Sun, 9 Jun 2024 11:07:51 +0100 Subject: [PATCH 09/12] wip --- hugr-core/src/partial_value.rs | 42 +- hugr-core/src/partial_value/test.rs | 54 +- hugr-core/src/partial_value/value_handle.rs | 50 +- hugr-passes/src/const_fold2/datalog.rs | 542 +++--------------- .../src/const_fold2/datalog/context.rs | 80 ++- hugr-passes/src/const_fold2/datalog/test.rs | 171 ++++++ hugr-passes/src/const_fold2/datalog/utils.rs | 281 +++++++++ 7 files changed, 685 insertions(+), 535 deletions(-) create mode 100644 hugr-passes/src/const_fold2/datalog/test.rs create mode 100644 hugr-passes/src/const_fold2/datalog/utils.rs diff --git a/hugr-core/src/partial_value.rs b/hugr-core/src/partial_value.rs index 3ab65f8c2..cf4e304f4 100644 --- a/hugr-core/src/partial_value.rs +++ b/hugr-core/src/partial_value.rs @@ -10,15 +10,16 @@ use crate::types::{Type, TypeEnum}; mod value_handle; -pub use value_handle::{ValueKey, ValueHandle}; - +pub use value_handle::{ValueHandle, ValueKey}; /// TODO shouldn't be pub #[derive(PartialEq, Clone, Eq)] pub struct PartialSum(HashMap>); impl PartialSum { - pub fn unit() -> Self { Self::variant(0,[]) } + pub fn unit() -> Self { + Self::variant(0, []) + } pub fn variant(tag: usize, values: impl IntoIterator) -> Self { Self([(tag, values.into_iter().collect())].into_iter().collect()) } @@ -55,15 +56,14 @@ impl PartialSum { Err(self)? }; if v.len() != r.len() { - return Err(self) + return Err(self); } match zip_eq(v.into_iter(), r.into_iter()) .map(|(v, t)| v.clone().try_into_value(t)) - .collect::,_>>() { - Ok(vs) => { - Value::sum(*k, vs, st.clone()).map_err(|_| self) - } - Err(_) => Err(self) + .collect::, _>>() + { + Ok(vs) => Value::sum(*k, vs, st.clone()).map_err(|_| self), + Err(_) => Err(self), } } @@ -180,7 +180,7 @@ impl TryFrom for PartialSum { .collect(); return Ok(Self([(*tag, vec)].into_iter().collect())); } - _ => () + _ => (), }; Err(value) } @@ -206,7 +206,6 @@ impl From for PartialValue { } } - impl PartialValue { // const BOTTOM: Self = Self::Bottom; // const BOTTOM_REF: &'static Self = &Self::BOTTOM; @@ -231,7 +230,6 @@ impl PartialValue { } } - pub fn try_into_value(self, typ: &Type) -> Result { let r = match self { Self::Value(v) => Ok(v.value().clone()), @@ -303,11 +301,13 @@ impl PartialValue { } fn value_handles_equal(&self, rhs: &ValueHandle) -> bool { - let Self::Value(lhs) = self else { unreachable!() }; + let Self::Value(lhs) = self else { + unreachable!() + }; lhs == rhs - // The following is a good idea if ValueHandle gains an Eq - // instance and so does not do this check: - // || lhs.value() == rhs.value() + // The following is a good idea if ValueHandle gains an Eq + // instance and so does not do this check: + // || lhs.value() == rhs.value() } pub fn join(mut self, other: Self) -> Self { @@ -451,16 +451,14 @@ impl PartialValue { pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { match self { Self::Bottom => Self::Bottom, - Self::PartialSum(ps) => { - ps.variant_field_value(variant, idx) - } + Self::PartialSum(ps) => ps.variant_field_value(variant, idx), Self::Value(v) => { if v.tag() == variant { Self::Value(v.index(idx)) } else { Self::Bottom } - }, + } Self::Top => Self::Top, } } @@ -476,7 +474,9 @@ impl PartialOrd for PartialValue { (_, Self::Bottom) => Some(Ordering::Greater), (Self::Top, _) => Some(Ordering::Greater), (_, Self::Top) => Some(Ordering::Less), - (Self::Value(_), Self::Value(v2)) => self.value_handles_equal(v2).then_some(Ordering::Equal), + (Self::Value(_), Self::Value(v2)) => { + self.value_handles_equal(v2).then_some(Ordering::Equal) + } (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), _ => None, } diff --git a/hugr-core/src/partial_value/test.rs b/hugr-core/src/partial_value/test.rs index 427f7983b..f31c33642 100644 --- a/hugr-core/src/partial_value/test.rs +++ b/hugr-core/src/partial_value/test.rs @@ -5,7 +5,11 @@ use lazy_static::lazy_static; use proptest::prelude::*; use crate::{ - ops::Value, std_extensions::arithmetic::int_types::{self, get_log_width, ConstInt, INT_TYPES, LOG_WIDTH_BOUND}, types::{CustomType, Type, TypeEnum} + ops::Value, + std_extensions::arithmetic::int_types::{ + self, get_log_width, ConstInt, INT_TYPES, LOG_WIDTH_BOUND, + }, + types::{CustomType, Type, TypeEnum}, }; use super::{PartialSum, PartialValue, ValueHandle, ValueKey}; @@ -37,7 +41,7 @@ impl TestSumLeafType { panic!("Expected int type, got {:#?}", t); } } - _ => () + _ => (), } } @@ -64,13 +68,17 @@ impl TestSumLeafType { fn partial_value_strategy(self) -> impl Strategy { match self { Self::Int(t) => { - let TypeEnum::Extension(ct) = t.as_type_enum() else { unreachable!() }; + let TypeEnum::Extension(ct) = t.as_type_enum() else { + unreachable!() + }; let lw = get_log_width(&ct.args()[0]).unwrap(); - (0u64..(1 << (2u64.pow(lw as u32) - 1))).prop_map(move |x| { - let ki = ConstInt::new_u(lw, x).unwrap(); - ValueHandle::new(ValueKey::new(ki.clone()), Arc::new(ki.into())).into() - }).boxed() - }, + (0u64..(1 << (2u64.pow(lw as u32) - 1))) + .prop_map(move |x| { + let ki = ConstInt::new_u(lw, x).unwrap(); + ValueHandle::new(ValueKey::new(ki.clone()), Arc::new(ki.into())).into() + }) + .boxed() + } Self::Unit => Just(PartialSum::unit().into()).boxed(), } } @@ -106,7 +114,7 @@ impl TestSumType { .map(|x| x.depth() + 1) .max() .unwrap_or(0); - Self::Branch(depth, vec.into()).into() + Self::Branch(depth, vec) } fn depth(&self) -> usize { @@ -136,7 +144,7 @@ impl TestSumType { } } - fn select(self) -> impl Strategy>)>> { + fn select(self) -> impl Strategy>)>> { match self { TestSumType::Branch(_, sop) => any::() .prop_map(move |i| { @@ -171,13 +179,13 @@ impl TestSumType { if prod.len() != v.len() { return false; } - if !zip_eq(prod, v).all(|(lhs, rhs)| lhs.type_check(&rhs)) { + if !zip_eq(prod, v).all(|(lhs, rhs)| lhs.type_check(rhs)) { return false; } } true } - (Self::Leaf(l), PartialValue::PartialSum(ps)) => l.type_check(&ps), + (Self::Leaf(l), PartialValue::PartialSum(ps)) => l.type_check(ps), } } } @@ -252,7 +260,11 @@ fn any_partial_value_of_type(ust: TestSumType) -> impl Strategy { let pvs = usts .into_iter() - .map(|x| any_partial_value_of_type(Arc::::unwrap_or_clone(x))) + .map(|x| { + any_partial_value_of_type( + Arc::::try_unwrap(x).unwrap_or_else(|x| x.as_ref().clone()), + ) + }) .collect_vec(); pvs.prop_map(move |pvs| PartialValue::variant(index, pvs)) .boxed() @@ -304,25 +316,27 @@ proptest! { #[test] fn bounded_lattice(v in any_partial_value()) { - prop_assert!(&v <= &PartialValue::Top); - prop_assert!(&v >= &PartialValue::Bottom); + prop_assert!(v <= PartialValue::Top); + prop_assert!(v >= PartialValue::Bottom); } #[test] - fn lattice_changed(v1 in any_partial_value()) { + fn meet_join_self_noop(v1 in any_partial_value()) { let mut subject = v1.clone(); assert!(!subject.join_mut(v1.clone())); + assert_eq!(subject, v1); assert!(!subject.meet_mut(v1.clone())); + assert_eq!(subject, v1); } #[test] fn lattice([v1,v2] in any_partial_values()) { let meet = v1.clone().meet(v2.clone()); - prop_assert!(&meet <= &v1, "meet not less <=: {:#?}", &meet); - prop_assert!(&meet <= &v2, "meet not less <=: {:#?}", &meet); + prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); + prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); let join = v1.clone().join(v2.clone()); - prop_assert!(&join >= &v1, "join not >=: {:#?}", &join); - prop_assert!(&join >= &v2, "join not >=: {:#?}", &join); + prop_assert!(join >= v1, "join not >=: {:#?}", &join); + prop_assert!(join >= v2, "join not >=: {:#?}", &join); } } diff --git a/hugr-core/src/partial_value/value_handle.rs b/hugr-core/src/partial_value/value_handle.rs index 7587ed2d8..dfb019872 100644 --- a/hugr-core/src/partial_value/value_handle.rs +++ b/hugr-core/src/partial_value/value_handle.rs @@ -1,7 +1,7 @@ use std::any::Any; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; -use std::hash::{DefaultHasher, Hash, Hasher}; use downcast_rs::Downcast; use itertools::Either; @@ -69,8 +69,8 @@ impl Hash for ValueKey { fn hash(&self, state: &mut H) { self.0.hash(state); match &self.1 { - Either::Left(n) => (0,n).hash(state), - Either::Right(v) => (1,v.hash()).hash(state), + Either::Left(n) => (0, n).hash(state), + Either::Right(v) => (1, v.hash()).hash(state), } } } @@ -82,7 +82,7 @@ impl From for ValueKey { } impl ValueKey { - pub fn new(k: impl ValueName) -> Self{ + pub fn new(k: impl ValueName) -> Self { Self(vec![], Either::Right(Arc::new(k))) } @@ -113,16 +113,24 @@ impl ValueHandle { } pub fn num_fields(&self) -> usize { - assert!(self.is_compound(), "ValueHandle::num_fields called on non-Sum, non-Tuple value: {:#?}", self); + assert!( + self.is_compound(), + "ValueHandle::num_fields called on non-Sum, non-Tuple value: {:#?}", + self + ); match self.value() { Value::Sum { values, .. } => values.len(), - | Value::Tuple { vs } => vs.len(), + Value::Tuple { vs } => vs.len(), _ => unreachable!(), } } pub fn tag(&self) -> usize { - assert!(self.is_compound(), "ValueHandle::tag called on non-Sum, non-Tuple value: {:#?}", self); + assert!( + self.is_compound(), + "ValueHandle::tag called on non-Sum, non-Tuple value: {:#?}", + self + ); match self.value() { Value::Sum { tag, .. } => *tag, Value::Tuple { .. } => 0, @@ -131,11 +139,16 @@ impl ValueHandle { } pub fn index(self: &ValueHandle, i: usize) -> ValueHandle { - assert!(i < self.num_fields(), "ValueHandle::index called with out-of-bounds index {}: {:#?}", i, &self); + assert!( + i < self.num_fields(), + "ValueHandle::index called with out-of-bounds index {}: {:#?}", + i, + &self + ); let vs = match self.value() { Value::Sum { values, .. } => values, Value::Tuple { vs, .. } => vs, - _ => unreachable!() + _ => unreachable!(), }; let v = vs[i].clone().into(); Self(self.0.clone().index(i), v) @@ -196,23 +209,30 @@ mod test { let k5 = From::::from(portgraph::NodeIndex::new(1).into()); let k6 = From::::from(portgraph::NodeIndex::new(2).into()); - assert_eq!(&k4,&k5); - assert_ne!(&k4,&k6); + assert_eq!(&k4, &k5); + assert_ne!(&k4, &k6); let k7 = k5.clone().index(3); let k4 = k4.index(3); - assert_eq!(&k4,&k7); + assert_eq!(&k4, &k7); let k5 = k5.index(2); - assert_ne!(&k5,&k7); + assert_ne!(&k5, &k7); } #[test] fn value_handle_eq() { - let k_i = ConstInt::new_u(4,2).unwrap(); - let subject_val = Arc::new(Value::sum(0, [k_i.clone().into()], SumType::new([vec![k_i.get_type()], vec![]])).unwrap()); + let k_i = ConstInt::new_u(4, 2).unwrap(); + let subject_val = Arc::new( + Value::sum( + 0, + [k_i.clone().into()], + SumType::new([vec![k_i.get_type()], vec![]]), + ) + .unwrap(), + ); let k1 = ValueKey::new("foo".to_string()); let v1 = ValueHandle::new(k1.clone(), subject_val.clone()); diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 25c87b460..b633c7347 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -1,9 +1,9 @@ -use ascent::lattice::{BoundedLattice, Dual, Lattice, ord_lattice::OrdLattice}; +use ascent::lattice::{ord_lattice::OrdLattice, BoundedLattice, Dual, Lattice}; use delegate::delegate; use itertools::{zip_eq, Itertools}; use std::collections::HashMap; use std::hash::{Hash, Hasher}; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use either::Either; use hugr_core::ops::{OpTag, OpTrait, Value}; @@ -12,324 +12,69 @@ use hugr_core::types::{EdgeKind, FunctionType, SumType, Type, TypeEnum, TypeRow} use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; mod context; +mod utils; -use context::DataflowContext; - -#[derive(PartialEq, Eq, Hash, PartialOrd, Clone, Debug)] -struct PV(PartialValue); - -impl From for PV { - fn from(inner: PartialValue) -> Self { - Self(inner) - } -} - -impl PV { - fn tuple_field_value(&self, idx: usize) -> Self { - self.0.tuple_field_value(idx).into() - } - - /// TODO the arguments here are not pretty, two usizes, better not mix them - /// up!!! - fn variant_field_value(&self, variant: usize, idx: usize) -> Self { - self.0.variant_field_value(variant, idx).into() - } - - fn supports_tag(&self, tag: usize) -> bool { - self.0.supports_tag(tag) - } -} - -impl From for PV { - fn from(inner: ValueHandle) -> Self { - Self(inner.into()) - } -} - -impl Lattice for PV { - fn meet(self, other: Self) -> Self { - self.0.meet(other.0).into() - } - - fn meet_mut(&mut self, other: Self) -> bool { - self.0.meet_mut(other.0) - } - - fn join(self, other: Self) -> Self { - self.0.join(other.0).into() - } - - fn join_mut(&mut self, other: Self) -> bool { - self.0.join_mut(other.0) - } -} - -impl BoundedLattice for PV { - fn bottom() -> Self { - PartialValue::bottom().into() - } - - fn top() -> Self { - PartialValue::top().into() - } -} - -#[derive(PartialEq, Clone, Eq, Hash, PartialOrd)] -struct ValueRow(Vec); - -impl ValueRow { - fn new(len: usize) -> Self { - Self(vec![PV::bottom(); len]) - } - - fn singleton(len: usize, idx: usize, v: PV) -> Self { - assert!(idx < len); - let mut r = Self::new(len); - r.0[idx] = v; - r - } - - fn singleton_from_row(r: &TypeRow, idx: usize, v: PV) -> Self { - Self::singleton(r.len(), idx, v) - } - - fn bottom_from_row(r: &TypeRow) -> Self { - Self::new(r.len()) - } - - fn iter<'b, 'a, H: HugrView>( - &'b self, - context: &'b Ctx<'a, H>, - n: Node, - ) -> impl Iterator + 'b { - zip_eq(value_inputs(context, n), self.0.iter()) - } - - // fn initialised(&self) -> bool { - // self.0.iter().all(|x| x != &PV::top()) - // } -} - -impl Lattice for ValueRow { - fn meet(mut self, other: Self) -> Self { - self.meet_mut(other); - self - } - - fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self - } - - fn join_mut(&mut self, other: Self) -> bool { - assert_eq!(self.0.len(), other.0.len()); - let mut changed = false; - for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { - changed |= v1.join_mut(v2); - } - changed - } - - fn meet_mut(&mut self, other: Self) -> bool { - assert_eq!(self.0.len(), other.0.len()); - let mut changed = false; - for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { - changed |= v1.meet_mut(v2); - } - changed - } -} - -impl IntoIterator for ValueRow { - type Item = PV; - - type IntoIter = as IntoIterator>::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -type Ctx<'a, H> = DataflowContext<'a, H>; -type ArcCtx<'a, H> = Arc>; - -fn bottom_row<'a, H: HugrView>(context: &Ctx<'a, H>, n: Node) -> ValueRow { - if let Some(sig) = context.hugr().signature(n) { - ValueRow::new(sig.input_count()) - } else { - ValueRow::new(0) - } -} - -fn singleton_in_row<'a, H: HugrView>( - context: &Ctx<'a, H>, - n: &Node, - ip: &IncomingPort, - v: PV, -) -> ValueRow { - let Some(sig) = context.hugr().signature(*n) else { - panic!("dougrulz"); - }; - if sig.input_count() <= ip.index() { - panic!( - "bad port index: {} >= {}: {}", - ip.index(), - sig.input_count(), - context.hugr().get_optype(*n).description() - ); - } - ValueRow::singleton_from_row(&context.hugr().signature(*n).unwrap().input, ip.index(), v) -} - -fn partial_value_from_load_constant<'a, H: HugrView>( - context: &context::DataflowContext<'a, H>, - node: Node, -) -> PV { - let load_op = context.hugr().get_optype(node).as_load_constant().unwrap(); - let const_node = context - .hugr() - .single_linked_output(node, load_op.constant_port()) - .unwrap() - .0; - let const_op = context.hugr().get_optype(const_node).as_const().unwrap(); - context - .node_value_handle(const_node, const_op.value()) - .into() -} - -fn partial_value_tuple_from_value_row(r: ValueRow) -> PV { - PartialValue::variant(0, r.into_iter().map(|x| x.0)).into() -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -enum IO { - Input, - Output, -} - -fn value_inputs<'a, H: HugrView>( - context: &Ctx<'a, H>, - n: Node, -) -> impl Iterator + 'a { - context.hugr().in_value_types(n).map(|x| x.0) -} - -fn value_outputs<'a, H: HugrView>( - context: &Ctx<'a, H>, - n: Node, -) -> impl Iterator + 'a { - context.hugr().out_value_types(n).map(|x| x.0) -} - -// todo this should work for dataflowblocks too -fn tail_loop_worker<'b, 'a, H: HugrView>( - context: &Ctx<'a, H>, - n: Node, - output_p: IncomingPort, - control_variant: usize, - v: &'b PV, -) -> impl Iterator + 'b where 'a: 'b { - let tail_loop_op = context.get_optype(n).as_tail_loop().unwrap(); - let num_variant_vals = if control_variant == 0 { - tail_loop_op.just_inputs.len() - } else { - tail_loop_op.just_outputs.len() - }; - let hugr = context.hugr(); - if output_p.index() == 0 { - Either::Left( - (0..num_variant_vals) - .map(move |i| (i.into(), v.variant_field_value(control_variant, i))), - ) - } else { - let v = if v.supports_tag(control_variant) { - v.clone() - } else { - PV::bottom() - }; - Either::Right(std::iter::once(((num_variant_vals + output_p.index() - 1).into(), v))) - }.inspect(move |x| assert!(matches!(hugr.get_optype(n).port_kind(x.0), Some(EdgeKind::Value(_))))) -} - -#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Clone)] -pub enum TailLoopTermination { - NeverTerminates, - SingleIteration, - Terminates, -} - -impl TailLoopTermination { - fn from_control_value(v: &PV) -> Self { - if v.supports_tag(1) && !v.supports_tag(0) { - Self::SingleIteration - } else if v.supports_tag(1) { - Self::Terminates - } else { - Self::NeverTerminates - } - } -} - -impl From for OrdLattice { - fn from(value: TailLoopTermination) -> Self { - Self(value) - } -} +pub use context::{ArcDataflowContext, DFContext, ValueCache}; +pub use utils::{TailLoopTermination, ValueRow, IO, PV}; ascent::ascent! { - struct Dataflow<'a, H: HugrView>; - relation context(ArcCtx<'a, H>); - relation node(ArcCtx<'a, H>, Node); - relation in_wire(ArcCtx<'a,H>, Node, IncomingPort); - relation out_wire(ArcCtx<'a,H>, Node, OutgoingPort); - lattice out_wire_value(ArcCtx<'a,H>, Node, OutgoingPort, PV); - lattice node_in_value_row(ArcCtx<'a,H>, Node, ValueRow); - lattice in_wire_value(ArcCtx<'a,H>, Node, IncomingPort, PV); + struct AscentProgram; + relation context(C); + relation node(C, Node); + relation in_wire(C, Node, IncomingPort); + relation out_wire(C, Node, OutgoingPort); + relation parent_of_node(C, Node, Node); + lattice out_wire_value(C, Node, OutgoingPort, PV); + lattice node_in_value_row(C, Node, ValueRow); + lattice in_wire_value(C, Node, IncomingPort, PV); + + node(c, n) <-- context(c), for n in c.hugr().nodes(); - node(c, n) <-- context(c), for n in c.nodes(); + in_wire(c, n,p) <-- node(c, n), for p in utils::value_inputs(c, *n); - in_wire(c, n,p) <-- node(c, n), for p in value_inputs(c, *n); + out_wire(c, n,p) <-- node(c, n), for p in utils::value_outputs(c, *n); - out_wire(c, n,p) <-- node(c, n), for p in value_outputs(c, *n); + parent_of_node(c, parent, child) <-- + node(c, child), if let Some(parent) = c.hugr().get_parent(*child); // All out wire values are initialised to Bottom. If any value is Bottom after // running we can infer that execution never reaches that value. out_wire_value(c, n,p, PV::bottom()) <-- out_wire(c, n,p); in_wire_value(c, n, ip, v) <-- in_wire(c, n, ip), - if let Some((m,op)) = c.single_linked_output(*n, *ip), + if let Some((m,op)) = c.hugr().single_linked_output(*n, *ip), out_wire_value(c, m, op, v); - node_in_value_row(c, n, bottom_row(c, *n)) <-- node(c, n); - node_in_value_row(c, n, singleton_in_row(c, n, p, v.clone())) <-- in_wire_value(c, n, p, v); + node_in_value_row(c, n, utils::bottom_row(c, *n)) <-- node(c, n); + node_in_value_row(c, n, utils::singleton_in_row(c, n, p, v.clone())) <-- in_wire_value(c, n, p, v); // LoadConstant - relation load_constant_node(ArcCtx<'a, H>, Node); - load_constant_node(c, n) <-- node(c, n), if c.get_optype(*n).is_load_constant(); + relation load_constant_node(C, Node); + load_constant_node(c, n) <-- node(c, n), if c.hugr().get_optype(*n).is_load_constant(); - out_wire_value(c, n, 0.into(), partial_value_from_load_constant(c, *n)) <-- + out_wire_value(c, n, 0.into(), utils::partial_value_from_load_constant(c, *n)) <-- load_constant_node(c, n); // MakeTuple - relation make_tuple_node(ArcCtx<'a,H>, Node); - make_tuple_node(c, n) <-- node(c, n), if c.get_optype(*n).is_make_tuple(); + relation make_tuple_node(C, Node); + make_tuple_node(c, n) <-- node(c, n), if c.hugr().get_optype(*n).is_make_tuple(); - out_wire_value(c, n, 0.into(), partial_value_tuple_from_value_row(vs.clone())) <-- + out_wire_value(c, n, 0.into(), utils::partial_value_tuple_from_value_row(vs.clone())) <-- make_tuple_node(c, n), node_in_value_row(c, n, vs); // UnpackTuple - relation unpack_tuple_node(ArcCtx<'a, H>, Node); - unpack_tuple_node(c,n) <-- node(c, n), if c.get_optype(*n).is_unpack_tuple(); + relation unpack_tuple_node(C, Node); + unpack_tuple_node(c,n) <-- node(c, n), if c.hugr().get_optype(*n).is_unpack_tuple(); out_wire_value(c, n, p, v.tuple_field_value(p.index())) <-- unpack_tuple_node(c, n), in_wire_value(c, n, IncomingPort::from(0), v), out_wire(c, n, p); // DFG - relation dfg_node(ArcCtx<'a, H>, Node); - dfg_node(c,n) <-- node(c, n), if c.get_optype(*n).is_dfg(); - relation dfg_io_node(ArcCtx<'a, H>, Node, Node, IO); + relation dfg_node(C, Node); + dfg_node(c,n) <-- node(c, n), if c.hugr().get_optype(*n).is_dfg(); + relation dfg_io_node(C, Node, Node, IO); dfg_io_node(c,dfg,n,io) <-- dfg_node(c,dfg), - if let Some([i,o]) = c.get_io(*dfg), + if let Some([i,o]) = c.hugr().get_io(*dfg), for (n, io) in [(i, IO::Input), (o, IO::Output)]; out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- @@ -339,14 +84,13 @@ ascent::ascent! { // TailLoop - relation tail_loop_node(ArcCtx<'a, H>, Node); - tail_loop_node(c,n) <-- node(c, n), if c.get_optype(*n).is_tail_loop(); - relation tail_loop_io_node(ArcCtx<'a, H>, Node, Node, IO); + relation tail_loop_node(C, Node); + tail_loop_node(c,n) <-- node(c, n), if c.hugr().get_optype(*n).is_tail_loop(); + relation tail_loop_io_node(C, Node, Node, IO); tail_loop_io_node(c,tl,n, io) <-- tail_loop_node(c,tl), - if let Some([i,o]) = c.get_io(*tl), + if let Some([i,o]) = c.hugr().get_io(*tl), for (n,io) in [(i,IO::Input), (o, IO::Output)]; - // inputs of tail loop propagate to Input node of child region out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- tail_loop_io_node(c,tl,i, IO::Input), in_wire_value(c, tl, p, v); @@ -357,43 +101,55 @@ ascent::ascent! { tail_loop_io_node(c,tl,i, IO::Input), tail_loop_io_node(c,tl,o, IO::Output), in_wire_value(c, o, output_p, output_v), - for (input_p, v) in tail_loop_worker(c, *tl, *output_p, 0, output_v); + for (input_p, v) in utils::tail_loop_worker(c, *tl, *output_p, 0, output_v); // Output node of child region propagate to outputs of tail loop out_wire_value(c, tl, p, v) <-- tail_loop_io_node(c,tl,o, IO::Output), in_wire_value(c, o, output_p, output_v), - for (p, v) in tail_loop_worker(c, *tl, *output_p, 1, output_v); + for (p, v) in utils::tail_loop_worker(c, *tl, *output_p, 1, output_v); - lattice tail_loop_termination(ArcCtx<'a,H>,Node,OrdLattice); + lattice tail_loop_termination(C,Node,OrdLattice); tail_loop_termination(c,tl,TailLoopTermination::NeverTerminates.into()) <-- tail_loop_node(c,tl); tail_loop_termination(c,tl,TailLoopTermination::from_control_value(v).into()) <-- tail_loop_node(c,tl), tail_loop_io_node(c,tl,o, IO::Output), in_wire_value(c, o, Into::::into(0usize), v); } -impl<'a, H: HugrView> Dataflow<'a, H> { +struct Machine<'a, H: HugrView> { + program: AscentProgram>, + cache: Arc>, +} + +impl<'a, H: HugrView> Machine<'a, H> { pub fn new() -> Self { - Self::default() + Self { + program: Default::default(), + cache: ValueCache::new(), + } } - pub fn run_hugr(&mut self, hugr: &'a H) -> ArcCtx<'a, H> { - let context = context::DataflowContext::new(hugr); - self.context.push((context.clone(),)); - self.run(); + pub fn run_hugr(&mut self, hugr: &'a H) -> ArcDataflowContext<'a, H> { + let context = ArcDataflowContext::new(hugr, self.cache.clone()); + self.program.context.push((context.clone(),)); + self.program.run(); context } pub fn read_out_wire_partial_value( &self, - context: &Ctx<'a, H>, + context: &ArcDataflowContext<'a, H>, w: Wire, ) -> Option { - self.out_wire_value.iter().find_map(|(c, n, p, v)| { - (c.as_ref() == context && &w.node() == n && &w.source() == p).then(|| v.clone().0) + self.program.out_wire_value.iter().find_map(|(c, n, p, v)| { + (c == context && &w.node() == n && &w.source() == p).then(|| v.clone().into()) }) } - pub fn read_out_wire_value(&self, context: &Ctx<'a, H>, w: Wire) -> Option { + pub fn read_out_wire_value( + &self, + context: &ArcDataflowContext<'a, H>, + w: Wire, + ) -> Option { // dbg!(&w); let pv = self.read_out_wire_partial_value(context, w)?; // dbg!(&pv); @@ -405,167 +161,19 @@ impl<'a, H: HugrView> Dataflow<'a, H> { pv.try_into_value(&typ).ok() } - pub fn tail_loop_terminates(&self, context: &Ctx<'a, H>, node: Node) -> TailLoopTermination { + pub fn tail_loop_terminates( + &self, + context: &ArcDataflowContext<'a, H>, + node: Node, + ) -> TailLoopTermination { assert!(context.get_optype(node).is_tail_loop()); - self.tail_loop_termination.iter().find_map(|(c,n,v)| (c.as_ref() == context && n == &node).then_some(v.0.clone())).unwrap() - } -} - -#[cfg(test)] -mod test { - use hugr_core::{ - builder::{Container, DFGBuilder, Dataflow, HugrBuilder, SubContainer}, extension::{prelude::BOOL_T, EMPTY_REG}, ops::{handle::NodeHandle, OpTrait, UnpackTuple, Value}, type_row, types::{FunctionType, SumType}, HugrView, OutgoingPort, Wire - }; - - use hugr_core::partial_value::PartialValue; - use itertools::Itertools; - - use crate::const_fold2::datalog::TailLoopTermination; - - #[test] - fn test_make_tuple() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); - let v1 = builder.add_load_value(Value::false_val()); - let v2 = builder.add_load_value(Value::true_val()); - let v3 = builder.make_tuple([v1, v2]).unwrap(); - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - - let mut machine = super::Dataflow::new(); - let c = machine.run_hugr(&hugr); - - let x = machine.read_out_wire_value(&c, v3).unwrap(); - assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); - } - - #[test] - fn test_unpack_tuple() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); - let v1 = builder.add_load_value(Value::false_val()); - let v2 = builder.add_load_value(Value::true_val()); - let v3 = builder.make_tuple([v1, v2]).unwrap(); - let [o1, o2] = builder - .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T, BOOL_T]), [v3]) - .unwrap() - .outputs_arr(); - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - - let mut machine = super::Dataflow::new(); - let c = machine.run_hugr(&hugr); - - let o1_r = machine.read_out_wire_value(&c, o1).unwrap(); - assert_eq!(o1_r, Value::false_val()); - let o2_r = machine.read_out_wire_value(&c, o2).unwrap(); - assert_eq!(o2_r, Value::true_val()); - } - - #[test] - fn test_unpack_const() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); - let v1 = builder.add_load_value(Value::tuple([Value::true_val()])); - let [o] = builder - .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T]), [v1]) + self.program + .tail_loop_termination + .iter() + .find_map(|(c, n, v)| (c == context && n == &node).then_some(v.0.clone())) .unwrap() - .outputs_arr(); - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - - let mut machine = super::Dataflow::new(); - let c = machine.run_hugr(&hugr); - - let o_r = machine.read_out_wire_value(&c, o).unwrap(); - assert_eq!(o_r, Value::true_val()); - } - - #[test] - fn test_tail_loop_never_iterates() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); - let r_v = Value::unit_sum(3, 6).unwrap(); - let r_w = builder.add_load_value( - Value::sum( - 1, - [r_v.clone()], - SumType::new([type_row![], r_v.get_type().into()]), - ) - .unwrap(), - ); - let tlb = builder - .tail_loop_builder([], [], vec![r_v.get_type()].into()) - .unwrap(); - let tail_loop = tlb.finish_with_outputs(r_w, []).unwrap(); - let [tl_o] = tail_loop.outputs_arr(); - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - - let mut machine = super::Dataflow::new(); - let c = machine.run_hugr(&hugr); - // dbg!(&machine.tail_loop_io_node); - // dbg!(&machine.out_wire_value); - - let o_r = machine.read_out_wire_value(&c, tl_o).unwrap(); - assert_eq!(o_r, r_v); - assert_eq!(TailLoopTermination::SingleIteration, machine.tail_loop_terminates(&c, tail_loop.node())) - } - - #[test] - fn test_tail_loop_always_iterates() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); - let r_w = builder - .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); - let tlb = builder - .tail_loop_builder([], [], vec![BOOL_T].into()) - .unwrap(); - let tail_loop = tlb.finish_with_outputs(r_w, []).unwrap(); - let [tl_o] = tail_loop.outputs_arr(); - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - - let mut machine = super::Dataflow::new(); - let c = machine.run_hugr(&hugr); - // dbg!(&machine.tail_loop_io_node); - // dbg!(&machine.out_wire_value); - - let o_r = machine.read_out_wire_partial_value(&c, tl_o).unwrap(); - assert_eq!(o_r, PartialValue::Bottom); - assert_eq!(TailLoopTermination::NeverTerminates, machine.tail_loop_terminates(&c, tail_loop.node())) - } - - #[test] - fn test_tail_loop_iterates_twice() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); - // let var_type = Type::new_sum([type_row![BOOL_T,BOOL_T], type_row![BOOL_T,BOOL_T]]); - - let true_w = builder.add_load_value(Value::true_val()); - let false_w = builder.add_load_value(Value::false_val()); - - // let r_w = builder - // .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); - let tlb = builder - .tail_loop_builder([], [(BOOL_T,false_w), (BOOL_T,true_w)], vec![].into()) - .unwrap(); - assert_eq!(tlb.loop_signature().unwrap().dataflow_signature().unwrap(), FunctionType::new_endo(type_row![BOOL_T,BOOL_T])); - let [in_w1,in_w2] = tlb.input_wires_arr(); - let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap(); - - // let optype = builder.hugr().get_optype(tail_loop.node()); - // for p in builder.hugr().node_outputs(tail_loop.node()) { - // use hugr_core::ops::OpType; - // println!("{:?}, {:?}", p, optype.port_kind(p)); - - // } - - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - // TODO once we can do conditionals put these wires inside `just_outputs` and - // we should be able to propagate their values - // let [o_w1, o_w2, _] = tail_loop.outputs_arr(); - - let mut machine = super::Dataflow::new(); - let c = machine.run_hugr(&hugr); - dbg!(&machine.tail_loop_io_node); - dbg!(&machine.out_wire_value); - - // TODO these hould be the propagated values - // let o_r1 = machine.read_out_wire_value(&c, o_w1).unwrap(); - // assert_eq!(o_r1, Value::false_val()); - // let o_r2 = machine.read_out_wire_value(&c, o_w2).unwrap(); - // assert_eq!(o_r2, Value::true_val()); - assert_eq!(TailLoopTermination::Terminates, machine.tail_loop_terminates(&c, tail_loop.node())) } } +#[cfg(test)] +mod test; diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index 4d981d0db..1661fb082 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -1,20 +1,22 @@ -use std::cell::RefCell; use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::sync::atomic::AtomicUsize; -use std::sync::Arc; - +use std::sync::{Arc, Mutex}; use hugr_core::ops::Value; use hugr_core::partial_value::{ValueHandle, ValueKey}; -use hugr_core::Node; +use hugr_core::{HugrView, Node}; #[derive(Clone)] pub struct ValueCache(HashMap>); impl ValueCache { - fn new() -> Self { + pub fn new() -> Arc> { + Arc::new(Mutex::new(Self::new_bare())) + } + + fn new_bare() -> Self { Self(HashMap::new()) } @@ -37,20 +39,21 @@ fn next_context_id() -> usize { pub struct DataflowContext<'a, H> { id: usize, hugr: &'a H, - cache: RefCell, + cache: Arc>, } impl<'a, H> DataflowContext<'a, H> { - pub fn new(hugr: &'a H) -> Arc { - Arc::new(Self { + fn new(hugr: &'a H, cache: Arc>) -> Self { + Self { id: next_context_id(), hugr, - cache: ValueCache::new().into(), - }) + cache, + } } - pub fn node_value_handle(&self, node: Node, value: &Value) -> ValueHandle { - self.cache.borrow_mut().get(node.into(), value) + pub fn get_value_handle(&self, key: impl Into, value: &Value) -> ValueHandle { + let mut guard = self.cache.lock().unwrap(); + guard.get(key.into(), value) } pub fn hugr(&self) -> &'a H { @@ -101,3 +104,56 @@ impl<'a, H> Deref for DataflowContext<'a, H> { self.hugr } } + +pub struct ArcDataflowContext<'a, H>(Arc>); + +impl<'a, H> ArcDataflowContext<'a, H> { + pub fn new(h: &'a H, cache: Arc>) -> Self { + Self(Arc::new(DataflowContext::new(h, cache))) + } +} + +impl<'a, H> Clone for ArcDataflowContext<'a, H> { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl<'a, H> Hash for ArcDataflowContext<'a, H> { + fn hash(&self, state: &mut HA) { + self.0.hash(state); + } +} + +impl<'a, H> PartialEq for ArcDataflowContext<'a, H> { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl<'a, H> Eq for ArcDataflowContext<'a, H> {} + +impl<'a, H> Deref for ArcDataflowContext<'a, H> { + type Target = DataflowContext<'a, H>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +pub trait DFContext: Clone + Eq + Hash { + type H: HugrView; + fn hugr(&self) -> &Self::H; + fn node_value_handle(&self, const_node: Node, value: &Value) -> ValueHandle; +} + +impl<'a, H: HugrView> DFContext for ArcDataflowContext<'a, H> { + type H = H; + fn hugr(&self) -> &Self::H { + self.0.hugr + } + + fn node_value_handle(&self, const_node: Node, value: &Value) -> ValueHandle { + self.0.get_value_handle(const_node, value) + } +} diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs new file mode 100644 index 000000000..31dd7f55f --- /dev/null +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -0,0 +1,171 @@ +use hugr_core::{ + builder::{Container, DFGBuilder, Dataflow, HugrBuilder, SubContainer}, + extension::{prelude::BOOL_T, EMPTY_REG}, + ops::{handle::NodeHandle, OpTrait, UnpackTuple, Value}, + type_row, + types::{FunctionType, SumType}, + HugrView, OutgoingPort, Wire, +}; + +use hugr_core::partial_value::PartialValue; +use itertools::Itertools; + +use super::*; + +#[test] +fn test_make_tuple() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let v1 = builder.add_load_value(Value::false_val()); + let v2 = builder.add_load_value(Value::true_val()); + let v3 = builder.make_tuple([v1, v2]).unwrap(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + let c = machine.run_hugr(&hugr); + + let x = machine.read_out_wire_value(&c, v3).unwrap(); + assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); +} + +#[test] +fn test_unpack_tuple() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let v1 = builder.add_load_value(Value::false_val()); + let v2 = builder.add_load_value(Value::true_val()); + let v3 = builder.make_tuple([v1, v2]).unwrap(); + let [o1, o2] = builder + .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T, BOOL_T]), [v3]) + .unwrap() + .outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + let c = machine.run_hugr(&hugr); + + let o1_r = machine.read_out_wire_value(&c, o1).unwrap(); + assert_eq!(o1_r, Value::false_val()); + let o2_r = machine.read_out_wire_value(&c, o2).unwrap(); + assert_eq!(o2_r, Value::true_val()); +} + +#[test] +fn test_unpack_const() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let v1 = builder.add_load_value(Value::tuple([Value::true_val()])); + let [o] = builder + .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T]), [v1]) + .unwrap() + .outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + let c = machine.run_hugr(&hugr); + + let o_r = machine.read_out_wire_value(&c, o).unwrap(); + assert_eq!(o_r, Value::true_val()); +} + +#[test] +fn test_tail_loop_never_iterates() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let r_v = Value::unit_sum(3, 6).unwrap(); + let r_w = builder.add_load_value( + Value::sum( + 1, + [r_v.clone()], + SumType::new([type_row![], r_v.get_type().into()]), + ) + .unwrap(), + ); + let tlb = builder + .tail_loop_builder([], [], vec![r_v.get_type()].into()) + .unwrap(); + let tail_loop = tlb.finish_with_outputs(r_w, []).unwrap(); + let [tl_o] = tail_loop.outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + let c = machine.run_hugr(&hugr); + // dbg!(&machine.tail_loop_io_node); + // dbg!(&machine.out_wire_value); + + let o_r = machine.read_out_wire_value(&c, tl_o).unwrap(); + assert_eq!(o_r, r_v); + assert_eq!( + TailLoopTermination::SingleIteration, + machine.tail_loop_terminates(&c, tail_loop.node()) + ) +} + +#[test] +fn test_tail_loop_always_iterates() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let r_w = builder + .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); + let tlb = builder + .tail_loop_builder([], [], vec![BOOL_T].into()) + .unwrap(); + let tail_loop = tlb.finish_with_outputs(r_w, []).unwrap(); + let [tl_o] = tail_loop.outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + let c = machine.run_hugr(&hugr); + // dbg!(&machine.tail_loop_io_node); + // dbg!(&machine.out_wire_value); + + let o_r = machine.read_out_wire_partial_value(&c, tl_o).unwrap(); + assert_eq!(o_r, PartialValue::Bottom); + assert_eq!( + TailLoopTermination::NeverTerminates, + machine.tail_loop_terminates(&c, tail_loop.node()) + ) +} + +#[test] +fn test_tail_loop_iterates_twice() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + // let var_type = Type::new_sum([type_row![BOOL_T,BOOL_T], type_row![BOOL_T,BOOL_T]]); + + let true_w = builder.add_load_value(Value::true_val()); + let false_w = builder.add_load_value(Value::false_val()); + + // let r_w = builder + // .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); + let tlb = builder + .tail_loop_builder([], [(BOOL_T, false_w), (BOOL_T, true_w)], vec![].into()) + .unwrap(); + assert_eq!( + tlb.loop_signature().unwrap().dataflow_signature().unwrap(), + FunctionType::new_endo(type_row![BOOL_T, BOOL_T]) + ); + let [in_w1, in_w2] = tlb.input_wires_arr(); + let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap(); + + // let optype = builder.hugr().get_optype(tail_loop.node()); + // for p in builder.hugr().node_outputs(tail_loop.node()) { + // use hugr_core::ops::OpType; + // println!("{:?}, {:?}", p, optype.port_kind(p)); + + // } + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + // TODO once we can do conditionals put these wires inside `just_outputs` and + // we should be able to propagate their values + // let [o_w1, o_w2, _] = tail_loop.outputs_arr(); + + let mut machine = Machine::new(); + let c = machine.run_hugr(&hugr); + // dbg!(&machine.tail_loop_io_node); + // dbg!(&machine.out_wire_value); + + // TODO these hould be the propagated values + // let o_r1 = machine.read_out_wire_value(&c, o_w1).unwrap(); + // assert_eq!(o_r1, Value::false_val()); + // let o_r2 = machine.read_out_wire_value(&c, o_w2).unwrap(); + // assert_eq!(o_r2, Value::true_val()); + assert_eq!( + TailLoopTermination::Terminates, + machine.tail_loop_terminates(&c, tail_loop.node()) + ) +} diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs new file mode 100644 index 000000000..a5de247fa --- /dev/null +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -0,0 +1,281 @@ +use ascent::lattice::{ord_lattice::OrdLattice, BoundedLattice, Dual, Lattice}; +use either::Either; +use hugr_core::{ + ops::OpTrait as _, + partial_value::{PartialValue, ValueHandle}, + types::{EdgeKind, TypeRow}, + HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, +}; +use itertools::zip_eq; + +use super::context::DFContext; + +#[derive(PartialEq, Eq, Hash, PartialOrd, Clone, Debug)] +pub struct PV(PartialValue); + +impl From for PV { + fn from(inner: PartialValue) -> Self { + Self(inner) + } +} + +impl PV { + pub fn tuple_field_value(&self, idx: usize) -> Self { + self.variant_field_value(0, idx) + } + + /// TODO the arguments here are not pretty, two usizes, better not mix them + /// up!!! + pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { + self.0.variant_field_value(variant, idx).into() + } + + pub fn supports_tag(&self, tag: usize) -> bool { + self.0.supports_tag(tag) + } +} + +impl From for PartialValue { + fn from(value: PV) -> Self { + value.0 + } +} + +impl From for PV { + fn from(inner: ValueHandle) -> Self { + Self(inner.into()) + } +} + +impl Lattice for PV { + fn meet(self, other: Self) -> Self { + self.0.meet(other.0).into() + } + + fn meet_mut(&mut self, other: Self) -> bool { + self.0.meet_mut(other.0) + } + + fn join(self, other: Self) -> Self { + self.0.join(other.0).into() + } + + fn join_mut(&mut self, other: Self) -> bool { + self.0.join_mut(other.0) + } +} + +impl BoundedLattice for PV { + fn bottom() -> Self { + PartialValue::bottom().into() + } + + fn top() -> Self { + PartialValue::top().into() + } +} + +#[derive(PartialEq, Clone, Eq, Hash, PartialOrd)] +pub struct ValueRow(Vec); + +impl ValueRow { + fn new(len: usize) -> Self { + Self(vec![PV::bottom(); len]) + } + + fn singleton(len: usize, idx: usize, v: PV) -> Self { + assert!(idx < len); + let mut r = Self::new(len); + r.0[idx] = v; + r + } + + fn singleton_from_row(r: &TypeRow, idx: usize, v: PV) -> Self { + Self::singleton(r.len(), idx, v) + } + + fn bottom_from_row(r: &TypeRow) -> Self { + Self::new(r.len()) + } + + fn iter<'b>( + &'b self, + context: &'b impl DFContext, + n: Node, + ) -> impl Iterator + 'b { + zip_eq(value_inputs(context, n), self.0.iter()) + } + + // fn initialised(&self) -> bool { + // self.0.iter().all(|x| x != &PV::top()) + // } +} + +impl Lattice for ValueRow { + fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn join_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.join_mut(v2); + } + changed + } + + fn meet_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.meet_mut(v2); + } + changed + } +} + +impl IntoIterator for ValueRow { + type Item = PV; + + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +pub(super) fn bottom_row(context: &impl DFContext, n: Node) -> ValueRow { + if let Some(sig) = context.hugr().signature(n) { + ValueRow::new(sig.input_count()) + } else { + ValueRow::new(0) + } +} + +pub(super) fn singleton_in_row( + context: &impl DFContext, + n: &Node, + ip: &IncomingPort, + v: PV, +) -> ValueRow { + let Some(sig) = context.hugr().signature(*n) else { + panic!("dougrulz"); + }; + if sig.input_count() <= ip.index() { + panic!( + "bad port index: {} >= {}: {}", + ip.index(), + sig.input_count(), + context.hugr().get_optype(*n).description() + ); + } + ValueRow::singleton_from_row(&context.hugr().signature(*n).unwrap().input, ip.index(), v) +} + +pub(super) fn partial_value_from_load_constant(context: &impl DFContext, node: Node) -> PV { + let load_op = context.hugr().get_optype(node).as_load_constant().unwrap(); + let const_node = context + .hugr() + .single_linked_output(node, load_op.constant_port()) + .unwrap() + .0; + let const_op = context.hugr().get_optype(const_node).as_const().unwrap(); + context + .node_value_handle(const_node, const_op.value()) + .into() +} + +pub(super) fn partial_value_tuple_from_value_row(r: ValueRow) -> PV { + PartialValue::variant(0, r.into_iter().map(|x| x.0)).into() +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum IO { + Input, + Output, +} + +pub(super) fn value_inputs( + context: &impl DFContext, + n: Node, +) -> impl Iterator + '_ { + context.hugr().in_value_types(n).map(|x| x.0) +} + +pub(super) fn value_outputs( + context: &impl DFContext, + n: Node, +) -> impl Iterator + '_ { + context.hugr().out_value_types(n).map(|x| x.0) +} + +// todo this should work for dataflowblocks too +pub(super) fn tail_loop_worker<'a>( + context: &'a impl DFContext, + n: Node, + output_p: IncomingPort, + control_variant: usize, + v: &'a PV, +) -> impl Iterator + 'a { + let tail_loop_op = context.hugr().get_optype(n).as_tail_loop().unwrap(); + let num_variant_vals = if control_variant == 0 { + tail_loop_op.just_inputs.len() + } else { + tail_loop_op.just_outputs.len() + }; + let hugr = context.hugr(); + if output_p.index() == 0 { + Either::Left( + (0..num_variant_vals) + .map(move |i| (i.into(), v.variant_field_value(control_variant, i))), + ) + } else { + let v = if v.supports_tag(control_variant) { + v.clone() + } else { + PV::bottom() + }; + Either::Right(std::iter::once(( + (num_variant_vals + output_p.index() - 1).into(), + v, + ))) + } + .inspect(move |x| { + assert!(matches!( + hugr.get_optype(n).port_kind(x.0), + Some(EdgeKind::Value(_)) + )) + }) +} + +#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Clone)] +pub enum TailLoopTermination { + NeverTerminates, + SingleIteration, + Terminates, +} + +impl TailLoopTermination { + pub fn from_control_value(v: &PV) -> Self { + if v.supports_tag(1) && !v.supports_tag(0) { + Self::SingleIteration + } else if v.supports_tag(1) { + Self::Terminates + } else { + Self::NeverTerminates + } + } +} + +impl From for OrdLattice { + fn from(value: TailLoopTermination) -> Self { + Self(value) + } +} From 81b49a3e1595d687d011db52ebc71b305f5d3689 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Sun, 9 Jun 2024 12:21:02 +0100 Subject: [PATCH 10/12] conditional --- hugr-passes/src/const_fold2/datalog.rs | 70 ++++++++++++++++++-- hugr-passes/src/const_fold2/datalog/test.rs | 51 ++++++++++++-- hugr-passes/src/const_fold2/datalog/utils.rs | 27 ++------ 3 files changed, 116 insertions(+), 32 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index b633c7347..fae2fa508 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -24,6 +24,7 @@ ascent::ascent! { relation in_wire(C, Node, IncomingPort); relation out_wire(C, Node, OutgoingPort); relation parent_of_node(C, Node, Node); + relation out_wire_value_proto(Node, OutgoingPort, PV); lattice out_wire_value(C, Node, OutgoingPort, PV); lattice node_in_value_row(C, Node, ValueRow); lattice in_wire_value(C, Node, IncomingPort, PV); @@ -40,6 +41,8 @@ ascent::ascent! { // All out wire values are initialised to Bottom. If any value is Bottom after // running we can infer that execution never reaches that value. out_wire_value(c, n,p, PV::bottom()) <-- out_wire(c, n,p); + out_wire_value(c, n, p, v) <-- out_wire(c,n,p) , out_wire_value_proto(n, p, v); + in_wire_value(c, n, ip, v) <-- in_wire(c, n, ip), if let Some((m,op)) = c.hugr().single_linked_output(*n, *ip), @@ -67,7 +70,10 @@ ascent::ascent! { relation unpack_tuple_node(C, Node); unpack_tuple_node(c,n) <-- node(c, n), if c.hugr().get_optype(*n).is_unpack_tuple(); - out_wire_value(c, n, p, v.tuple_field_value(p.index())) <-- unpack_tuple_node(c, n), in_wire_value(c, n, IncomingPort::from(0), v), out_wire(c, n, p); + out_wire_value(c, n, p, v.tuple_field_value(p.index())) <-- + unpack_tuple_node(c, n), + in_wire_value(c, n, IncomingPort::from(0), v), + out_wire(c, n, p); // DFG relation dfg_node(C, Node); @@ -101,18 +107,59 @@ ascent::ascent! { tail_loop_io_node(c,tl,i, IO::Input), tail_loop_io_node(c,tl,o, IO::Output), in_wire_value(c, o, output_p, output_v), - for (input_p, v) in utils::tail_loop_worker(c, *tl, *output_p, 0, output_v); + if let Some(tailloop) = c.hugr().get_optype(*tl).as_tail_loop(), + let variant_len = tailloop.just_inputs.len(), + for (input_p, v) in utils::tail_loop_worker(*output_p, 0, variant_len, output_v); + // Output node of child region propagate to outputs of tail loop out_wire_value(c, tl, p, v) <-- tail_loop_io_node(c,tl,o, IO::Output), in_wire_value(c, o, output_p, output_v), - for (p, v) in utils::tail_loop_worker(c, *tl, *output_p, 1, output_v); + if let Some(tailloop) = c.hugr().get_optype(*tl).as_tail_loop(), + let variant_len = tailloop.just_outputs.len(), + for (p, v) in utils::tail_loop_worker(*output_p, 1, variant_len, output_v); lattice tail_loop_termination(C,Node,OrdLattice); - tail_loop_termination(c,tl,TailLoopTermination::NeverTerminates.into()) <-- tail_loop_node(c,tl); - tail_loop_termination(c,tl,TailLoopTermination::from_control_value(v).into()) <-- tail_loop_node(c,tl), + tail_loop_termination(c,tl,TailLoopTermination::NeverTerminates.into()) <-- + tail_loop_node(c,tl); + tail_loop_termination(c,tl,TailLoopTermination::from_control_value(v).into()) <-- + tail_loop_node(c,tl), tail_loop_io_node(c,tl,o, IO::Output), in_wire_value(c, o, Into::::into(0usize), v); + + // Conditional + relation conditional_node(C, Node); + relation case_node(C,Node,usize, Node); + relation case_io_node(C, Node, Node, IO); + + conditional_node (c,n)<-- node(c, n), if c.hugr().get_optype(*n).is_conditional(); + case_node(c,cond,i, case) <-- conditional_node(c,cond), + for (i, case) in c.hugr().children(*cond).enumerate(), + if c.hugr().get_optype(case).is_case(); + case_io_node(c,case, n, io) <-- case_node(c, _, _, case), + if let Some([i,o]) = c.hugr().get_io(*case), + for (n,io) in [(i,IO::Input), (o, IO::Output)]; + + // inputs of conditional propagate into case nodes + out_wire_value(c, i_node, i_p, v) <-- + case_node(c, cond, case_index, case), + case_io_node(c, case, i_node, IO::Input), + in_wire_value(c, cond, cond_in_p, cond_in_v), + if let Some(conditional) = c.hugr().get_optype(*cond).as_conditional(), + let variant_len = conditional.sum_rows[*case_index].len(), + for (i_p, v) in utils::tail_loop_worker(*cond_in_p, *case_index, variant_len, cond_in_v); + + // outputs of case nodes propagate to outputs of conditional + out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- + case_node(c, cond, _, case), + case_io_node(c, case, o, IO::Output), + in_wire_value(c, o, o_p, v); + + lattice case_reachable(C, Node, Node, bool); + case_reachable(c, cond, case, reachable) <-- case_node(c,cond,i,case), + in_wire_value(c, cond, IncomingPort::from(0), v), + let reachable = v.supports_tag(*i); + } struct Machine<'a, H: HugrView> { @@ -128,6 +175,10 @@ impl<'a, H: HugrView> Machine<'a, H> { } } + pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator) { + self.program.out_wire_value_proto.extend(wires.into_iter().map(|(w,v)| (w.node(), w.source(), v.into()))); + } + pub fn run_hugr(&mut self, hugr: &'a H) -> ArcDataflowContext<'a, H> { let context = ArcDataflowContext::new(hugr, self.cache.clone()); self.program.context.push((context.clone(),)); @@ -173,6 +224,15 @@ impl<'a, H: HugrView> Machine<'a, H> { .find_map(|(c, n, v)| (c == context && n == &node).then_some(v.0.clone())) .unwrap() } + + pub fn case_reachable(&self, + context: &ArcDataflowContext<'a, H>, + case: Node,) -> bool { + assert!(context.get_optype(case).is_case()); + let cond = context.hugr().get_parent(case).unwrap(); + assert!(context.get_optype(cond).is_conditional()); + self.program.case_reachable.iter().find_map(|(c,cond2,case2,i)| (c == context && &cond == cond2 && &case == case2).then_some(*i)).unwrap() + } } #[cfg(test)] diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index 31dd7f55f..563c83130 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -1,14 +1,8 @@ use hugr_core::{ - builder::{Container, DFGBuilder, Dataflow, HugrBuilder, SubContainer}, - extension::{prelude::BOOL_T, EMPTY_REG}, - ops::{handle::NodeHandle, OpTrait, UnpackTuple, Value}, - type_row, - types::{FunctionType, SumType}, - HugrView, OutgoingPort, Wire, + builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{prelude::BOOL_T, ExtensionSet, EMPTY_REG}, ops::{handle::NodeHandle, OpTrait, UnpackTuple, Value}, partial_value::PartialSum, type_row, types::{FunctionType, SumType}, Extension }; use hugr_core::partial_value::PartialValue; -use itertools::Itertools; use super::*; @@ -169,3 +163,46 @@ fn test_tail_loop_iterates_twice() { machine.tail_loop_terminates(&c, tail_loop.node()) ) } + +#[test] +fn conditional() { + let variants = vec![type_row![], type_row![], type_row![BOOL_T]]; + let cond_t = Type::new_sum(variants.clone()); + let mut builder = DFGBuilder::new(FunctionType::new(Into::::into(cond_t),type_row![])).unwrap(); + let [arg_w] = builder.input_wires_arr(); + + let true_w = builder.add_load_value(Value::true_val()); + let false_w = builder.add_load_value(Value::false_val()); + + let mut cond_builder = builder.conditional_builder((variants, arg_w), [(BOOL_T,true_w)], type_row!(BOOL_T,BOOL_T), ExtensionSet::default()).unwrap(); + // will be unreachable + let case1_b = cond_builder.case_builder(0).unwrap(); + let case1 = case1_b.finish_with_outputs([false_w,false_w]).unwrap(); + + let case2_b = cond_builder.case_builder(1).unwrap(); + let [c2a] = case2_b.input_wires_arr(); + let case2 = case2_b.finish_with_outputs([false_w,c2a]).unwrap(); + + let case3_b = cond_builder.case_builder(2).unwrap(); + let [c3_1,c3_2] = case3_b.input_wires_arr(); + let case3 = case3_b.finish_with_outputs([c3_1,false_w]).unwrap(); + + let cond = cond_builder.finish_sub_container().unwrap(); + + let [cond_o1,cond_o2] = cond.outputs_arr(); + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + let arg_pv = PartialValue::variant(1, []).join(PartialValue::variant(2,[PartialValue::variant(0,[])])); + machine.propolutate_out_wires([(arg_w, arg_pv)]); + let c = machine.run_hugr(&hugr); + + let cond_r1 = machine.read_out_wire_value(&c, cond_o1).unwrap(); + assert_eq!(cond_r1, Value::false_val()); + assert!(machine.read_out_wire_value(&c, cond_o2).is_none()); + + assert!(!machine.case_reachable(&c, case1.node())); + assert!(machine.case_reachable(&c, case2.node())); + assert!(machine.case_reachable(&c, case3.node())); +} diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index a5de247fa..9a63bab1e 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -216,43 +216,30 @@ pub(super) fn value_outputs( context.hugr().out_value_types(n).map(|x| x.0) } +// TODO rename, this is about expanding input variants into output rows // todo this should work for dataflowblocks too pub(super) fn tail_loop_worker<'a>( - context: &'a impl DFContext, - n: Node, output_p: IncomingPort, - control_variant: usize, + variant_tag: usize, + variant_len: usize, v: &'a PV, ) -> impl Iterator + 'a { - let tail_loop_op = context.hugr().get_optype(n).as_tail_loop().unwrap(); - let num_variant_vals = if control_variant == 0 { - tail_loop_op.just_inputs.len() - } else { - tail_loop_op.just_outputs.len() - }; - let hugr = context.hugr(); if output_p.index() == 0 { Either::Left( - (0..num_variant_vals) - .map(move |i| (i.into(), v.variant_field_value(control_variant, i))), + (0..variant_len) + .map(move |i| (i.into(), v.variant_field_value(variant_tag, i))), ) } else { - let v = if v.supports_tag(control_variant) { + let v = if v.supports_tag(variant_tag) { v.clone() } else { PV::bottom() }; Either::Right(std::iter::once(( - (num_variant_vals + output_p.index() - 1).into(), + (variant_len + output_p.index() - 1).into(), v, ))) } - .inspect(move |x| { - assert!(matches!( - hugr.get_optype(n).port_kind(x.0), - Some(EdgeKind::Value(_)) - )) - }) } #[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Clone)] From aeac385281b690cd37b3c9950ec92beb3860beb4 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Jun 2024 07:07:42 +0100 Subject: [PATCH 11/12] don't need per-parent io node relations --- hugr-passes/src/const_fold2/datalog.rs | 60 +++++++++++++------------- 1 file changed, 29 insertions(+), 31 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index fae2fa508..b300fdcc0 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -20,11 +20,13 @@ pub use utils::{TailLoopTermination, ValueRow, IO, PV}; ascent::ascent! { struct AscentProgram; relation context(C); + relation out_wire_value_proto(Node, OutgoingPort, PV); + relation node(C, Node); relation in_wire(C, Node, IncomingPort); relation out_wire(C, Node, OutgoingPort); relation parent_of_node(C, Node, Node); - relation out_wire_value_proto(Node, OutgoingPort, PV); + relation io_node(C, Node, Node, IO); lattice out_wire_value(C, Node, OutgoingPort, PV); lattice node_in_value_row(C, Node, ValueRow); lattice in_wire_value(C, Node, IncomingPort, PV); @@ -38,12 +40,15 @@ ascent::ascent! { parent_of_node(c, parent, child) <-- node(c, child), if let Some(parent) = c.hugr().get_parent(*child); - // All out wire values are initialised to Bottom. If any value is Bottom after - // running we can infer that execution never reaches that value. + io_node(c, parent, child, io) <-- node(c, parent), + if let Some([i,o]) = c.hugr().get_io(*parent), + for (child,io) in [(i,IO::Input),(o,IO::Output)]; + // We support prepopulating out_wire_value via out_wire_value_proto. + // + // out wires that do not have prepopulation values are initialised to bottom. out_wire_value(c, n,p, PV::bottom()) <-- out_wire(c, n,p); out_wire_value(c, n, p, v) <-- out_wire(c,n,p) , out_wire_value_proto(n, p, v); - in_wire_value(c, n, ip, v) <-- in_wire(c, n, ip), if let Some((m,op)) = c.hugr().single_linked_output(*n, *ip), out_wire_value(c, m, op, v); @@ -52,6 +57,7 @@ ascent::ascent! { node_in_value_row(c, n, utils::bottom_row(c, *n)) <-- node(c, n); node_in_value_row(c, n, utils::singleton_in_row(c, n, p, v.clone())) <-- in_wire_value(c, n, p, v); + // LoadConstant relation load_constant_node(C, Node); load_constant_node(c, n) <-- node(c, n), if c.hugr().get_optype(*n).is_load_constant(); @@ -59,6 +65,7 @@ ascent::ascent! { out_wire_value(c, n, 0.into(), utils::partial_value_from_load_constant(c, *n)) <-- load_constant_node(c, n); + // MakeTuple relation make_tuple_node(C, Node); make_tuple_node(c, n) <-- node(c, n), if c.hugr().get_optype(*n).is_make_tuple(); @@ -66,6 +73,7 @@ ascent::ascent! { out_wire_value(c, n, 0.into(), utils::partial_value_tuple_from_value_row(vs.clone())) <-- make_tuple_node(c, n), node_in_value_row(c, n, vs); + // UnpackTuple relation unpack_tuple_node(C, Node); unpack_tuple_node(c,n) <-- node(c, n), if c.hugr().get_optype(*n).is_unpack_tuple(); @@ -75,45 +83,38 @@ ascent::ascent! { in_wire_value(c, n, IncomingPort::from(0), v), out_wire(c, n, p); + // DFG relation dfg_node(C, Node); dfg_node(c,n) <-- node(c, n), if c.hugr().get_optype(*n).is_dfg(); - relation dfg_io_node(C, Node, Node, IO); - dfg_io_node(c,dfg,n,io) <-- dfg_node(c,dfg), - if let Some([i,o]) = c.hugr().get_io(*dfg), - for (n, io) in [(i, IO::Input), (o, IO::Output)]; - out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- - dfg_io_node(c,dfg,i, IO::Input), in_wire_value(c, dfg, p, v); - out_wire_value(c, dfg, OutgoingPort::from(p.index()), v) <-- - dfg_io_node(c,dfg,o, IO::Output), in_wire_value(c, o, p, v); + out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- dfg_node(c,dfg), + io_node(c, dfg, i, IO::Input), in_wire_value(c, dfg, p, v); + + out_wire_value(c, dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(c,dfg), + io_node(c,dfg,o, IO::Output), in_wire_value(c, o, p, v); // TailLoop relation tail_loop_node(C, Node); tail_loop_node(c,n) <-- node(c, n), if c.hugr().get_optype(*n).is_tail_loop(); - relation tail_loop_io_node(C, Node, Node, IO); - tail_loop_io_node(c,tl,n, io) <-- tail_loop_node(c,tl), - if let Some([i,o]) = c.hugr().get_io(*tl), - for (n,io) in [(i,IO::Input), (o, IO::Output)]; // inputs of tail loop propagate to Input node of child region - out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- - tail_loop_io_node(c,tl,i, IO::Input), in_wire_value(c, tl, p, v); - + out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- tail_loop_node(c, tl), + io_node(c,tl,i, IO::Input), in_wire_value(c, tl, p, v); // Output node of child region propagate to Input node of child region - out_wire_value(c, i, input_p, v) <-- - tail_loop_io_node(c,tl,i, IO::Input), - tail_loop_io_node(c,tl,o, IO::Output), + out_wire_value(c, i, input_p, v) <-- tail_loop_node(c, tl), + io_node(c,tl,i, IO::Input), + io_node(c,tl,o, IO::Output), in_wire_value(c, o, output_p, output_v), if let Some(tailloop) = c.hugr().get_optype(*tl).as_tail_loop(), let variant_len = tailloop.just_inputs.len(), for (input_p, v) in utils::tail_loop_worker(*output_p, 0, variant_len, output_v); // Output node of child region propagate to outputs of tail loop - out_wire_value(c, tl, p, v) <-- - tail_loop_io_node(c,tl,o, IO::Output), + out_wire_value(c, tl, p, v) <-- tail_loop_node(c, tl), + io_node(c,tl,o, IO::Output), in_wire_value(c, o, output_p, output_v), if let Some(tailloop) = c.hugr().get_optype(*tl).as_tail_loop(), let variant_len = tailloop.just_outputs.len(), @@ -124,26 +125,23 @@ ascent::ascent! { tail_loop_node(c,tl); tail_loop_termination(c,tl,TailLoopTermination::from_control_value(v).into()) <-- tail_loop_node(c,tl), - tail_loop_io_node(c,tl,o, IO::Output), + io_node(c,tl,o, IO::Output), in_wire_value(c, o, Into::::into(0usize), v); + // Conditional relation conditional_node(C, Node); relation case_node(C,Node,usize, Node); - relation case_io_node(C, Node, Node, IO); conditional_node (c,n)<-- node(c, n), if c.hugr().get_optype(*n).is_conditional(); case_node(c,cond,i, case) <-- conditional_node(c,cond), for (i, case) in c.hugr().children(*cond).enumerate(), if c.hugr().get_optype(case).is_case(); - case_io_node(c,case, n, io) <-- case_node(c, _, _, case), - if let Some([i,o]) = c.hugr().get_io(*case), - for (n,io) in [(i,IO::Input), (o, IO::Output)]; // inputs of conditional propagate into case nodes out_wire_value(c, i_node, i_p, v) <-- case_node(c, cond, case_index, case), - case_io_node(c, case, i_node, IO::Input), + io_node(c, case, i_node, IO::Input), in_wire_value(c, cond, cond_in_p, cond_in_v), if let Some(conditional) = c.hugr().get_optype(*cond).as_conditional(), let variant_len = conditional.sum_rows[*case_index].len(), @@ -152,7 +150,7 @@ ascent::ascent! { // outputs of case nodes propagate to outputs of conditional out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- case_node(c, cond, _, case), - case_io_node(c, case, o, IO::Output), + io_node(c, case, o, IO::Output), in_wire_value(c, o, o_p, v); lattice case_reachable(C, Node, Node, bool); From 067bf8da097389a7baeed77997165468991bcc76 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Jun 2024 08:39:49 +0100 Subject: [PATCH 12/12] tidying --- hugr-core/src/partial_value/test.rs | 8 +- hugr-passes/Cargo.toml | 2 + hugr-passes/src/const_fold2/datalog.rs | 50 +++--- hugr-passes/src/const_fold2/datalog/test.rs | 37 ++-- hugr-passes/src/const_fold2/datalog/utils.rs | 173 +++++++++++++++++-- 5 files changed, 216 insertions(+), 54 deletions(-) diff --git a/hugr-core/src/partial_value/test.rs b/hugr-core/src/partial_value/test.rs index f31c33642..35fbf5373 100644 --- a/hugr-core/src/partial_value/test.rs +++ b/hugr-core/src/partial_value/test.rs @@ -316,15 +316,19 @@ proptest! { #[test] fn bounded_lattice(v in any_partial_value()) { - prop_assert!(v <= PartialValue::Top); - prop_assert!(v >= PartialValue::Bottom); + prop_assert!(v <= PartialValue::top()); + prop_assert!(v >= PartialValue::bottom()); } #[test] fn meet_join_self_noop(v1 in any_partial_value()) { let mut subject = v1.clone(); + + assert_eq!(v1.clone(), v1.clone().join(v1.clone())); assert!(!subject.join_mut(v1.clone())); assert_eq!(subject, v1); + + assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); assert!(!subject.meet_mut(v1.clone())); assert_eq!(subject, v1); } diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 06d9975ea..92967f9c9 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -27,3 +27,5 @@ extension_inference = ["hugr-core/extension_inference"] [dev-dependencies] rstest = { workspace = true } +proptest = { workspace = true } +proptest-derive = { workspace = true } diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index b300fdcc0..98e80d0a9 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -58,6 +58,10 @@ ascent::ascent! { node_in_value_row(c, n, utils::singleton_in_row(c, n, p, v.clone())) <-- in_wire_value(c, n, p, v); + // Per node-type rules + // TODO do all leaf ops with a rule + // define `fn propagate_leaf_op(Context, Node, ValueRow) -> ValueRow + // LoadConstant relation load_constant_node(C, Node); load_constant_node(c, n) <-- node(c, n), if c.hugr().get_optype(*n).is_load_constant(); @@ -104,29 +108,35 @@ ascent::ascent! { io_node(c,tl,i, IO::Input), in_wire_value(c, tl, p, v); // Output node of child region propagate to Input node of child region - out_wire_value(c, i, input_p, v) <-- tail_loop_node(c, tl), - io_node(c,tl,i, IO::Input), - io_node(c,tl,o, IO::Output), - in_wire_value(c, o, output_p, output_v), - if let Some(tailloop) = c.hugr().get_optype(*tl).as_tail_loop(), + out_wire_value(c, in_n, out_p, v) <-- tail_loop_node(c, tl_n), + io_node(c,tl_n,in_n, IO::Input), + io_node(c,tl_n,out_n, IO::Output), + node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node + if out_in_row[0].supports_tag(0), // if it is possible for tag to be 0 + if let Some(tailloop) = c.hugr().get_optype(*tl_n).as_tail_loop(), let variant_len = tailloop.just_inputs.len(), - for (input_p, v) in utils::tail_loop_worker(*output_p, 0, variant_len, output_v); + for (out_p, v) in out_in_row.iter(c, *out_n).flat_map( + |(input_p, v)| utils::outputs_for_variant(input_p, 0, variant_len, v) + ); // Output node of child region propagate to outputs of tail loop - out_wire_value(c, tl, p, v) <-- tail_loop_node(c, tl), - io_node(c,tl,o, IO::Output), - in_wire_value(c, o, output_p, output_v), - if let Some(tailloop) = c.hugr().get_optype(*tl).as_tail_loop(), + out_wire_value(c, tl_n, out_p, v) <-- tail_loop_node(c, tl_n), + io_node(c,tl_n,out_n, IO::Output), + node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node + if out_in_row[0].supports_tag(1), // if it is possible for the tag to be 1 + if let Some(tailloop) = c.hugr().get_optype(*tl_n).as_tail_loop(), let variant_len = tailloop.just_outputs.len(), - for (p, v) in utils::tail_loop_worker(*output_p, 1, variant_len, output_v); + for (out_p, v) in out_in_row.iter(c, *out_n).flat_map( + |(input_p, v)| utils::outputs_for_variant(input_p, 1, variant_len, v) + ); - lattice tail_loop_termination(C,Node,OrdLattice); - tail_loop_termination(c,tl,TailLoopTermination::NeverTerminates.into()) <-- - tail_loop_node(c,tl); - tail_loop_termination(c,tl,TailLoopTermination::from_control_value(v).into()) <-- - tail_loop_node(c,tl), - io_node(c,tl,o, IO::Output), - in_wire_value(c, o, Into::::into(0usize), v); + lattice tail_loop_termination(C,Node,TailLoopTermination); + tail_loop_termination(c,tl_n,TailLoopTermination::bottom()) <-- + tail_loop_node(c,tl_n); + tail_loop_termination(c,tl_n,TailLoopTermination::from_control_value(v)) <-- + tail_loop_node(c,tl_n), + io_node(c,tl,out_n, IO::Output), + in_wire_value(c, out_n, IncomingPort::from(0), v); // Conditional @@ -145,7 +155,7 @@ ascent::ascent! { in_wire_value(c, cond, cond_in_p, cond_in_v), if let Some(conditional) = c.hugr().get_optype(*cond).as_conditional(), let variant_len = conditional.sum_rows[*case_index].len(), - for (i_p, v) in utils::tail_loop_worker(*cond_in_p, *case_index, variant_len, cond_in_v); + for (i_p, v) in utils::outputs_for_variant(*cond_in_p, *case_index, variant_len, cond_in_v); // outputs of case nodes propagate to outputs of conditional out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- @@ -219,7 +229,7 @@ impl<'a, H: HugrView> Machine<'a, H> { self.program .tail_loop_termination .iter() - .find_map(|(c, n, v)| (c == context && n == &node).then_some(v.0.clone())) + .find_map(|(c, n, v)| (c == context && n == &node).then_some(*v)) .unwrap() } diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index 563c83130..e35ee0a47 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -86,7 +86,7 @@ fn test_tail_loop_never_iterates() { let o_r = machine.read_out_wire_value(&c, tl_o).unwrap(); assert_eq!(o_r, r_v); assert_eq!( - TailLoopTermination::SingleIteration, + TailLoopTermination::ExactlyZeroContinues, machine.tail_loop_terminates(&c, tail_loop.node()) ) } @@ -96,22 +96,29 @@ fn test_tail_loop_always_iterates() { let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); let r_w = builder .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); + let true_w = builder.add_load_value(Value::true_val()); + let tlb = builder - .tail_loop_builder([], [], vec![BOOL_T].into()) + .tail_loop_builder([], [(BOOL_T,true_w)], vec![BOOL_T].into()) .unwrap(); - let tail_loop = tlb.finish_with_outputs(r_w, []).unwrap(); - let [tl_o] = tail_loop.outputs_arr(); + + // r_w has tag 0, so we always continue; + // we put true in our "other_output", but we should not propagate this to + // output because r_w never supports 1. + let tail_loop = tlb.finish_with_outputs(r_w, [true_w]).unwrap(); + + let [tl_o1, tl_o2] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); let c = machine.run_hugr(&hugr); - // dbg!(&machine.tail_loop_io_node); - // dbg!(&machine.out_wire_value); - let o_r = machine.read_out_wire_partial_value(&c, tl_o).unwrap(); - assert_eq!(o_r, PartialValue::Bottom); + let o_r1 = machine.read_out_wire_partial_value(&c, tl_o1).unwrap(); + assert_eq!(o_r1, PartialValue::bottom()); + let o_r2 = machine.read_out_wire_partial_value(&c, tl_o2).unwrap(); + assert_eq!(o_r2, PartialValue::bottom()); assert_eq!( - TailLoopTermination::NeverTerminates, + TailLoopTermination::bottom(), machine.tail_loop_terminates(&c, tail_loop.node()) ) } @@ -146,20 +153,20 @@ fn test_tail_loop_iterates_twice() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); // TODO once we can do conditionals put these wires inside `just_outputs` and // we should be able to propagate their values - // let [o_w1, o_w2, _] = tail_loop.outputs_arr(); + let [o_w1, o_w2, _] = tail_loop.outputs_arr(); let mut machine = Machine::new(); let c = machine.run_hugr(&hugr); // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); - // TODO these hould be the propagated values - // let o_r1 = machine.read_out_wire_value(&c, o_w1).unwrap(); - // assert_eq!(o_r1, Value::false_val()); - // let o_r2 = machine.read_out_wire_value(&c, o_w2).unwrap(); + // TODO these hould be the propagated values for now they will bt join(true,false) + let o_r1 = machine.read_out_wire_partial_value(&c, o_w1).unwrap(); + // assert_eq!(o_r1, PartialValue::top()); + let o_r2 = machine.read_out_wire_partial_value(&c, o_w2).unwrap(); // assert_eq!(o_r2, Value::true_val()); assert_eq!( - TailLoopTermination::Terminates, + TailLoopTermination::Top, machine.tail_loop_terminates(&c, tail_loop.node()) ) } diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 9a63bab1e..00162e73b 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -1,4 +1,11 @@ -use ascent::lattice::{ord_lattice::OrdLattice, BoundedLattice, Dual, Lattice}; +// proptest-derive generates many of these warnings. +// https://github.com/rust-lang/rust/issues/120363 +// https://github.com/proptest-rs/proptest/issues/447 +#![cfg_attr(test, allow(non_local_definitions))] + +use std::{cmp::Ordering, ops::Index}; + +use ascent::lattice::{BoundedLattice, Lattice}; use either::Either; use hugr_core::{ ops::OpTrait as _, @@ -8,6 +15,9 @@ use hugr_core::{ }; use itertools::zip_eq; +#[cfg(test)] +use proptest_derive::Arbitrary; + use super::context::DFContext; #[derive(PartialEq, Eq, Hash, PartialOrd, Clone, Debug)] @@ -98,7 +108,7 @@ impl ValueRow { Self::new(r.len()) } - fn iter<'b>( + pub fn iter<'b>( &'b self, context: &'b impl DFContext, n: Node, @@ -151,6 +161,16 @@ impl IntoIterator for ValueRow { } } +impl Index for ValueRow where + Vec: Index +{ + type Output = as Index>::Output; + + fn index(&self, index: Idx) -> &Self::Output { + self.0.index(index) + } +} + pub(super) fn bottom_row(context: &impl DFContext, n: Node) -> ValueRow { if let Some(sig) = context.hugr().signature(n) { ValueRow::new(sig.input_count()) @@ -216,9 +236,26 @@ pub(super) fn value_outputs( context.hugr().out_value_types(n).map(|x| x.0) } -// TODO rename, this is about expanding input variants into output rows -// todo this should work for dataflowblocks too -pub(super) fn tail_loop_worker<'a>( +// We have several cases where sum types propagate to different places depending +// on their variant tag: +// - From the input of a conditional to the inputs of it's case nodes +// - From the input of the output node of a tail loop to the output of the input node of the tail loop +// - From the input of the output node of a tail loop to the output of tail loop node +// - From the input of a the output node of a dataflow block to the output of the input node of a dataflow block +// - From the input of a the output node of a dataflow block to the output of the cfg +// +// For a value `v` on an incoming porg `output_p`, compute the (out port,value) +// pairs that should be propagated for a given variant tag. We must also supply +// the length of this variant because it cannot always be deduced from the other +// inputs. +// +// If `v` does not support `variant_tag`, then all propagated values will be bottom.` +// +// If `output_p.index()` is 0 then the result is the contents of the variant. +// Otherwise, it is the single "other_output". +// +// TODO doctests +pub(super) fn outputs_for_variant<'a>( output_p: IncomingPort, variant_tag: usize, variant_len: usize, @@ -242,27 +279,129 @@ pub(super) fn tail_loop_worker<'a>( } } -#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Clone)] +#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] +#[cfg_attr(test, derive(Arbitrary))] pub enum TailLoopTermination { - NeverTerminates, - SingleIteration, - Terminates, + Bottom, + ExactlyZeroContinues, + Top, } impl TailLoopTermination { pub fn from_control_value(v: &PV) -> Self { - if v.supports_tag(1) && !v.supports_tag(0) { - Self::SingleIteration - } else if v.supports_tag(1) { - Self::Terminates + let (may_continue, may_break) = (v.supports_tag(0),v.supports_tag(1)); + if may_break && !may_continue { + Self::ExactlyZeroContinues + } else if may_break && may_continue { + Self::top() } else { - Self::NeverTerminates + Self::bottom() + } + } +} + +impl PartialOrd for TailLoopTermination { + fn partial_cmp(&self, other: &Self) -> Option { + if self == other { + return Some(std::cmp::Ordering::Equal); + }; + match (self, other) { + (Self::Bottom,_) => Some(Ordering::Less), + (_,Self::Bottom) => Some(Ordering::Greater), + (Self::Top,_) => Some(Ordering::Greater), + (_,Self::Top) => Some(Ordering::Less), + _ => None } } } -impl From for OrdLattice { - fn from(value: TailLoopTermination) -> Self { - Self(value) +impl Lattice for TailLoopTermination { + fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn meet_mut(&mut self, other: Self) -> bool { + // let new_self = &mut self; + match (*self).partial_cmp(&other) { + Some(Ordering::Greater) => { + *self = other; + true + } + Some(_) => false, + _ => { + *self = Self::Bottom; + true + } + } + } + + fn join_mut(&mut self, other: Self) -> bool { + match (*self).partial_cmp(&other) { + Some(Ordering::Less) => { + *self = other; + true + } + Some(_) => false, + _ => { + *self = Self::Top; + true + } + } + } +} + +impl BoundedLattice for TailLoopTermination { + fn bottom() -> Self { + Self::Bottom + + } + + fn top() -> Self { + Self::Top + } +} + +#[cfg(test)] +#[cfg_attr(test, allow(non_local_definitions))] +mod test { + use super::*; + use proptest::prelude::*; + + proptest! { + #[test] + fn bounded_lattice(v: TailLoopTermination) { + prop_assert!(v <= TailLoopTermination::top()); + prop_assert!(v >= TailLoopTermination::bottom()); + } + + #[test] + fn meet_join_self_noop(v1: TailLoopTermination) { + let mut subject = v1.clone(); + + assert_eq!(v1.clone(), v1.clone().join(v1.clone())); + assert!(!subject.join_mut(v1.clone())); + assert_eq!(subject, v1); + + assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); + assert!(!subject.meet_mut(v1.clone())); + assert_eq!(subject, v1); + } + + #[test] + fn lattice(v1: TailLoopTermination, v2: TailLoopTermination) { + let meet = v1.clone().meet(v2.clone()); + prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); + prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); + + let join = v1.clone().join(v2.clone()); + prop_assert!(join >= v1, "join not >=: {:#?}", &join); + prop_assert!(join >= v2, "join not >=: {:#?}", &join); + } } }