diff --git a/hugr-core/src/lib.rs b/hugr-core/src/lib.rs index 6bd2a262d..318bd02fc 100644 --- a/hugr-core/src/lib.rs +++ b/hugr-core/src/lib.rs @@ -14,6 +14,7 @@ pub mod extension; pub mod hugr; pub mod macros; pub mod ops; +pub mod partial_value; pub mod std_extensions; pub mod types; pub mod utils; diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 5f059f4f0..486c910e9 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -98,7 +98,7 @@ impl AsRef for Const { } } -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] #[serde(tag = "v")] /// A value that can be stored as a static constant. Representing core types and /// extension types. @@ -136,6 +136,49 @@ pub enum Value { }, } +impl PartialEq for Value { + fn eq(&self, other: &Self) -> bool { + let cmp_tuple_sum = + |tuple_values: &[Value], tag: usize, sum_values: &[Value], sum_type: &SumType| { + tag == 0 && sum_type.num_variants() == 1 && tuple_values == sum_values + }; + match (self, other) { + (Self::Extension { e: e1 }, Self::Extension { e: e2 }) => e1 == e2, + (Self::Function { hugr: h1 }, Self::Function { hugr: h2 }) => h1 == h2, + (Self::Tuple { vs: v1 }, Self::Tuple { vs: v2 }) => v1 == v2, + ( + Self::Sum { + tag: t1, + values: v1, + sum_type: s1, + }, + Self::Sum { + tag: t2, + values: v2, + sum_type: s2, + }, + ) => t1 == t2 && v1 == v2 && s1 == s2, + ( + Self::Tuple { vs }, + Self::Sum { + tag, + values, + sum_type, + }, + ) + | ( + Self::Sum { + tag, + values, + sum_type, + }, + Self::Tuple { vs }, + ) => cmp_tuple_sum(vs, *tag, values, sum_type), + _ => false, + } + } +} + /// An opaque newtype around a [`Box`](CustomConst). 
/// /// This type has special serialization behaviour in order to support diff --git a/hugr-core/src/ops/module.rs b/hugr-core/src/ops/module.rs index 40c189b68..fef5105fb 100644 --- a/hugr-core/src/ops/module.rs +++ b/hugr-core/src/ops/module.rs @@ -128,7 +128,7 @@ impl OpTrait for AliasDefn { } /// A type alias declaration. Resolved at link time. -#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] #[cfg_attr(test, derive(proptest_derive::Arbitrary))] pub struct AliasDecl { /// Alias name diff --git a/hugr-core/src/partial_value.rs b/hugr-core/src/partial_value.rs new file mode 100644 index 000000000..cf4e304f4 --- /dev/null +++ b/hugr-core/src/partial_value.rs @@ -0,0 +1,487 @@ +#![allow(missing_docs)] +use std::cmp::Ordering; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; + +use itertools::{zip_eq, Itertools as _}; + +use crate::ops::Value; +use crate::types::{Type, TypeEnum}; + +mod value_handle; + +pub use value_handle::{ValueHandle, ValueKey}; + +/// TODO shouldn't be pub +#[derive(PartialEq, Clone, Eq)] +pub struct PartialSum(HashMap>); + +impl PartialSum { + pub fn unit() -> Self { + Self::variant(0, []) + } + pub fn variant(tag: usize, values: impl IntoIterator) -> Self { + Self([(tag, values.into_iter().collect())].into_iter().collect()) + } + + pub fn num_variants(&self) -> usize { + self.0.len() + } + + fn assert_variants(&self) { + assert_ne!(self.num_variants(), 0); + for pv in self.0.values().flat_map(|x| x.iter()) { + pv.assert_invariants(); + } + } + + pub fn variant_field_value(&self, variant: usize, idx: usize) -> PartialValue { + if let Some(row) = self.0.get(&variant) { + assert!(row.len() > idx); + row[idx].clone() + } else { + PartialValue::bottom() + } + } + + pub fn try_into_value(self, typ: &Type) -> Result { + let Ok((k, v)) = self.0.iter().exactly_one() else { + Err(self)? 
+ }; + + let TypeEnum::Sum(st) = typ.as_type_enum() else { + Err(self)? + }; + let Some(r) = st.get_variant(*k) else { + Err(self)? + }; + if v.len() != r.len() { + return Err(self); + } + match zip_eq(v.into_iter(), r.into_iter()) + .map(|(v, t)| v.clone().try_into_value(t)) + .collect::, _>>() + { + Ok(vs) => Value::sum(*k, vs, st.clone()).map_err(|_| self), + Err(_) => Err(self), + } + } + + // unsafe because we panic if any common rows have different lengths + fn join_mut_unsafe(&mut self, other: Self) -> bool { + let mut changed = false; + + for (k, v) in other.0 { + if let Some(row) = self.0.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.join_mut(rhs); + } + } else { + self.0.insert(k, v); + changed = true; + } + } + changed + } + + // unsafe because we panic if any common rows have different lengths + fn meet_mut_unsafe(&mut self, other: Self) -> bool { + let mut changed = false; + let mut keys_to_remove = vec![]; + for k in self.0.keys() { + if !other.0.contains_key(k) { + keys_to_remove.push(*k); + } + } + for (k, v) in other.0 { + if let Some(row) = self.0.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.meet_mut(rhs); + } + } else { + keys_to_remove.push(k); + } + } + for k in keys_to_remove { + self.0.remove(&k); + changed = true; + } + changed + } + + pub fn supports_tag(&self, tag: usize) -> bool { + self.0.contains_key(&tag) + } +} + +impl PartialOrd for PartialSum { + fn partial_cmp(&self, other: &Self) -> Option { + let max_key = self.0.keys().chain(other.0.keys()).copied().max().unwrap(); + let (mut keys1, mut keys2) = (vec![0; max_key + 1], vec![0; max_key + 1]); + for k in self.0.keys() { + keys1[*k] = 1; + } + + for k in other.0.keys() { + keys2[*k] = 1; + } + + if let Some(ord) = keys1.partial_cmp(&keys2) { + if ord != Ordering::Equal { + return Some(ord); + } + } else { + return None; + } + for (k, lhs) in &self.0 { + let Some(rhs) = other.0.get(&k) else { + 
unreachable!() + }; + match lhs.partial_cmp(rhs) { + Some(Ordering::Equal) => continue, + x => { + return x; + } + } + } + Some(Ordering::Equal) + } +} + +impl std::fmt::Debug for PartialSum { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl Hash for PartialSum { + fn hash(&self, state: &mut H) { + for (k, v) in &self.0 { + k.hash(state); + v.hash(state); + } + } +} + +impl TryFrom for PartialSum { + type Error = ValueHandle; + + fn try_from(value: ValueHandle) -> Result { + match value.value() { + Value::Tuple { vs } => { + let vec = (0..vs.len()) + .map(|i| PartialValue::from(value.index(i)).into()) + .collect(); + return Ok(Self([(0, vec)].into_iter().collect())); + } + Value::Sum { tag, values, .. } => { + let vec = (0..values.len()) + .map(|i| PartialValue::from(value.index(i)).into()) + .collect(); + return Ok(Self([(*tag, vec)].into_iter().collect())); + } + _ => (), + }; + Err(value) + } +} + +#[derive(PartialEq, Clone, Eq, Hash, Debug)] +pub enum PartialValue { + Bottom, + Value(ValueHandle), + PartialSum(PartialSum), + Top, +} + +impl From for PartialValue { + fn from(v: ValueHandle) -> Self { + TryInto::::try_into(v).map_or_else(Self::Value, Self::PartialSum) + } +} + +impl From for PartialValue { + fn from(v: PartialSum) -> Self { + Self::PartialSum(v) + } +} + +impl PartialValue { + // const BOTTOM: Self = Self::Bottom; + // const BOTTOM_REF: &'static Self = &Self::BOTTOM; + + // fn initialised(&self) -> bool { + // !self.is_top() + // } + + // fn is_top(&self) -> bool { + // self == &PartialValue::Top + // } + + fn assert_invariants(&self) { + match self { + Self::PartialSum(ps) => { + ps.assert_variants(); + } + Self::Value(v) => { + assert!(matches!(v.clone().into(), Self::Value(_))) + } + _ => {} + } + } + + pub fn try_into_value(self, typ: &Type) -> Result { + let r = match self { + Self::Value(v) => Ok(v.value().clone()), + Self::PartialSum(ps) => 
ps.try_into_value(typ).map_err(Self::PartialSum), + x => Err(x), + }?; + assert_eq!(typ, &r.get_type()); + Ok(r) + } + + fn join_mut_value_handle(&mut self, vh: ValueHandle) -> bool { + self.assert_invariants(); + let mut new_self = self; + match &mut new_self { + Self::Top => false, + new_self @ Self::Value(_) => { + let Self::Value(v) = *new_self else { + unreachable!() + }; + if v == &vh { + false + } else { + **new_self = Self::Top; + true + } + } + s @ Self::PartialSum(_) => match vh.into() { + Self::Value(_) => { + **s = Self::Top; + true + } + other => s.join_mut(other), + }, + new_self @ Self::Bottom => { + **new_self = vh.into(); + true + } + } + } + + fn meet_mut_value_handle(&mut self, vh: ValueHandle) -> bool { + self.assert_invariants(); + let mut new_self = self; + match &mut new_self { + Self::Bottom => false, + new_self @ Self::Value(_) => { + let Self::Value(v) = *new_self else { + unreachable!() + }; + if v == &vh { + false + } else { + **new_self = Self::Bottom; + true + } + } + new_self @ Self::PartialSum(_) => match vh.into() { + Self::Value(_) => { + **new_self = Self::Bottom; + true + } + other => new_self.join_mut(other), + }, + new_self @ Self::Top => { + **new_self = vh.into(); + true + } + } + } + + fn value_handles_equal(&self, rhs: &ValueHandle) -> bool { + let Self::Value(lhs) = self else { + unreachable!() + }; + lhs == rhs + // The following is a good idea if ValueHandle gains an Eq + // instance and so does not do this check: + // || lhs.value() == rhs.value() + } + + pub fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + pub fn join_mut(&mut self, other: Self) -> bool { + // println!("join {self:?}\n{:?}", &other); + let mut new_self = self; + let changed = match (&mut new_self, other) { + (Self::Top, _) => false, + (new_self, other @ Self::Top) => { + **new_self = other; + true + } + (_, Self::Bottom) => false, + (new_self @ Self::Bottom, other) => { + **new_self = other; + true + } + (new_self @ 
Self::Value(_), Self::Value(h2)) => { + if new_self.value_handles_equal(&h2) { + false + } else { + **new_self = Self::Top; + true + } + } + (new_self @ Self::PartialSum(_), Self::PartialSum(ps2)) => { + let Self::PartialSum(ps1) = *new_self else { + unreachable!() + }; + + ps1.join_mut_unsafe(ps2) + } + (new_self @ Self::Value(_), other) => { + let mut old_self = other; + std::mem::swap(*new_self, &mut old_self); + let Self::Value(h) = old_self else { + unreachable!() + }; + new_self.join_mut_value_handle(h) + } + (new_self, Self::Value(h)) => new_self.join_mut_value_handle(h), + // (new_self, _) => { + // **new_self = Self::Top; + // false + // } + }; + // if changed { + // println!("join new self: {:?}", s); + // } + changed + } + + pub fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + pub fn meet_mut(&mut self, other: Self) -> bool { + let mut new_self = self; + let changed = match (&mut new_self, other) { + (Self::Bottom, _) => false, + (new_self, other @ Self::Bottom) => { + **new_self = other; + true + } + (_, Self::Top) => false, + (new_self @ Self::Top, other) => { + **new_self = other; + true + } + (new_self @ Self::Value(_), Self::Value(h2)) => { + if new_self.value_handles_equal(&h2) { + false + } else { + **new_self = Self::Bottom; + true + } + } + (new_self @ Self::PartialSum(_), Self::PartialSum(ps2)) => { + let Self::PartialSum(ps1) = *new_self else { + unreachable!() + }; + ps1.meet_mut_unsafe(ps2) + } + (new_self @ Self::Value(_), other @ Self::PartialSum(_)) => { + let mut old_self = other; + std::mem::swap(*new_self, &mut old_self); + let Self::Value(h) = old_self else { + unreachable!() + }; + new_self.meet_mut_value_handle(h) + } + (s @ Self::PartialSum(_), Self::Value(h)) => s.meet_mut_value_handle(h), + // (new_self, _) => { + // **new_self = Self::Bottom; + // false + // } + }; + // if changed { + // println!("join new self: {:?}", s); + // } + changed + } + + pub fn top() -> Self { + Self::Top + } + + pub fn 
bottom() -> Self { + Self::Bottom + } + + pub fn variant(tag: usize, values: impl IntoIterator) -> Self { + PartialSum::variant(tag, values).into() + } + + pub fn unit() -> Self { + Self::variant(0, []) + } + + pub fn supports_tag(&self, tag: usize) -> bool { + match self { + PartialValue::Bottom => false, + PartialValue::Value(v) => v.tag() == tag, // can never be a sum or tuple + PartialValue::PartialSum(ps) => ps.supports_tag(tag), + PartialValue::Top => true, + } + } + + /// TODO docs + /// just delegate to variant_field_value + pub fn tuple_field_value(&self, idx: usize) -> Self { + self.variant_field_value(0, idx) + } + + /// TODO docs + pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { + match self { + Self::Bottom => Self::Bottom, + Self::PartialSum(ps) => ps.variant_field_value(variant, idx), + Self::Value(v) => { + if v.tag() == variant { + Self::Value(v.index(idx)) + } else { + Self::Bottom + } + } + Self::Top => Self::Top, + } + } +} + +impl PartialOrd for PartialValue { + fn partial_cmp(&self, other: &Self) -> Option { + use std::cmp::Ordering; + match (self, other) { + (Self::Bottom, Self::Bottom) => Some(Ordering::Equal), + (Self::Top, Self::Top) => Some(Ordering::Equal), + (Self::Bottom, _) => Some(Ordering::Less), + (_, Self::Bottom) => Some(Ordering::Greater), + (Self::Top, _) => Some(Ordering::Greater), + (_, Self::Top) => Some(Ordering::Less), + (Self::Value(_), Self::Value(v2)) => { + self.value_handles_equal(v2).then_some(Ordering::Equal) + } + (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), + _ => None, + } + } +} + +#[cfg(test)] +mod test; diff --git a/hugr-core/src/partial_value/test.rs b/hugr-core/src/partial_value/test.rs new file mode 100644 index 000000000..35fbf5373 --- /dev/null +++ b/hugr-core/src/partial_value/test.rs @@ -0,0 +1,346 @@ +use std::sync::Arc; + +use itertools::{zip_eq, Either, Itertools as _}; +use lazy_static::lazy_static; +use proptest::prelude::*; + +use crate::{ + 
ops::Value, + std_extensions::arithmetic::int_types::{ + self, get_log_width, ConstInt, INT_TYPES, LOG_WIDTH_BOUND, + }, + types::{CustomType, Type, TypeEnum}, +}; + +use super::{PartialSum, PartialValue, ValueHandle, ValueKey}; +impl Arbitrary for ValueHandle { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { + // prop_oneof![ + + // ] + todo!() + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +enum TestSumLeafType { + Int(Type), + Unit, +} + +impl TestSumLeafType { + fn assert_invariants(&self) { + match self { + Self::Int(t) => { + if let TypeEnum::Extension(ct) = t.as_type_enum() { + assert_eq!("int", ct.name()); + assert_eq!(&int_types::EXTENSION_ID, ct.extension()); + } else { + panic!("Expected int type, got {:#?}", t); + } + } + _ => (), + } + } + + fn get_type(&self) -> Type { + match self { + Self::Int(t) => t.clone(), + Self::Unit => Type::UNIT, + } + } + + fn type_check(&self, ps: &PartialSum) -> bool { + match self { + Self::Int(_) => false, + Self::Unit => { + if let Ok((0, v)) = ps.0.iter().exactly_one() { + v.is_empty() + } else { + false + } + } + } + } + + fn partial_value_strategy(self) -> impl Strategy { + match self { + Self::Int(t) => { + let TypeEnum::Extension(ct) = t.as_type_enum() else { + unreachable!() + }; + let lw = get_log_width(&ct.args()[0]).unwrap(); + (0u64..(1 << (2u64.pow(lw as u32) - 1))) + .prop_map(move |x| { + let ki = ConstInt::new_u(lw, x).unwrap(); + ValueHandle::new(ValueKey::new(ki.clone()), Arc::new(ki.into())).into() + }) + .boxed() + } + Self::Unit => Just(PartialSum::unit().into()).boxed(), + } + } +} + +impl Arbitrary for TestSumLeafType { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { + let int_strat = (0..LOG_WIDTH_BOUND).prop_map(|i| Self::Int(INT_TYPES[i as usize].clone())); + prop_oneof![Just(TestSumLeafType::Unit), int_strat].boxed() + } +} + +#[derive(Debug, 
PartialEq, Eq, Clone)] +enum TestSumType { + Branch(usize, Vec>>), + Leaf(TestSumLeafType), +} + +impl TestSumType { + const UNIT: TestSumLeafType = TestSumLeafType::Unit; + + fn leaf(v: Type) -> Self { + TestSumType::Leaf(TestSumLeafType::Int(v)) + } + + fn branch(vs: impl IntoIterator>>) -> Self { + let vec = vs.into_iter().collect_vec(); + let depth: usize = vec + .iter() + .flat_map(|x| x.iter()) + .map(|x| x.depth() + 1) + .max() + .unwrap_or(0); + Self::Branch(depth, vec) + } + + fn depth(&self) -> usize { + match self { + TestSumType::Branch(x, _) => *x, + TestSumType::Leaf(_) => 0, + } + } + + fn is_leaf(&self) -> bool { + self.depth() == 0 + } + + fn assert_invariants(&self) { + match self { + TestSumType::Branch(d, sop) => { + assert!(!sop.is_empty(), "No variants"); + for v in sop.iter().flat_map(|x| x.iter()) { + assert!(v.depth() < *d); + v.assert_invariants(); + } + } + TestSumType::Leaf(l) => { + l.assert_invariants(); + } + _ => (), + } + } + + fn select(self) -> impl Strategy>)>> { + match self { + TestSumType::Branch(_, sop) => any::() + .prop_map(move |i| { + let index = i.index(sop.len()); + Either::Right((index, sop[index].clone())) + }) + .boxed(), + TestSumType::Leaf(l) => Just(Either::Left(l)).boxed(), + } + } + + fn get_type(&self) -> Type { + match self { + TestSumType::Branch(_, sop) => Type::new_sum( + sop.iter() + .map(|row| row.iter().map(|x| x.get_type()).collect_vec().into()), + ), + TestSumType::Leaf(l) => l.get_type(), + } + } + + fn type_check(&self, pv: &PartialValue) -> bool { + match (self, pv) { + (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, + (_, PartialValue::Value(v)) => self.get_type() == v.get_type(), + (TestSumType::Branch(_, sop), PartialValue::PartialSum(ps)) => { + for (k, v) in &ps.0 { + if *k >= sop.len() { + return false; + } + let prod = &sop[*k]; + if prod.len() != v.len() { + return false; + } + if !zip_eq(prod, v).all(|(lhs, rhs)| lhs.type_check(rhs)) { + return false; + } + } + true + } + 
(Self::Leaf(l), PartialValue::PartialSum(ps)) => l.type_check(ps), + } + } +} + +impl From for TestSumType { + fn from(value: TestSumLeafType) -> Self { + Self::Leaf(value) + } +} + +#[derive(Clone, PartialEq, Eq, Debug)] +struct UnarySumTypeParams { + depth: usize, + branch_width: usize, +} + +impl UnarySumTypeParams { + pub fn descend(mut self, d: usize) -> Self { + assert!(d < self.depth); + self.depth = d; + self + } +} + +impl Default for UnarySumTypeParams { + fn default() -> Self { + Self { + depth: 3, + branch_width: 3, + } + } +} + +impl Arbitrary for TestSumType { + type Parameters = UnarySumTypeParams; + type Strategy = BoxedStrategy; + fn arbitrary_with( + params @ UnarySumTypeParams { + depth, + branch_width, + }: Self::Parameters, + ) -> Self::Strategy { + if depth == 0 { + any::().prop_map_into().boxed() + } else { + (0..depth) + .prop_flat_map(move |d| { + prop::collection::vec( + prop::collection::vec( + any_with::(params.clone().descend(d)).prop_map_into(), + 0..branch_width, + ), + 1..=branch_width, + ) + .prop_map(TestSumType::branch) + }) + .boxed() + } + } +} + +proptest! 
{ + #[test] + fn unary_sum_type_valid(ust: TestSumType) { + ust.assert_invariants(); + } +} + +fn any_partial_value_of_type(ust: TestSumType) -> impl Strategy { + ust.select().prop_flat_map(|x| match x { + Either::Left(l) => l.partial_value_strategy().boxed(), + Either::Right((index, usts)) => { + let pvs = usts + .into_iter() + .map(|x| { + any_partial_value_of_type( + Arc::::try_unwrap(x).unwrap_or_else(|x| x.as_ref().clone()), + ) + }) + .collect_vec(); + pvs.prop_map(move |pvs| PartialValue::variant(index, pvs)) + .boxed() + } + }) +} + +fn any_partial_value_with( + params: ::Parameters, +) -> impl Strategy { + any_with::(params).prop_flat_map(any_partial_value_of_type) +} + +fn any_partial_value() -> impl Strategy { + any_partial_value_with(Default::default()) +} + +fn any_partial_values() -> impl Strategy { + any::().prop_flat_map(|ust| { + TryInto::<[_; N]>::try_into( + (0..N) + .map(|_| any_partial_value_of_type(ust.clone())) + .collect_vec(), + ) + .unwrap() + }) +} + +fn any_typed_partial_value() -> impl Strategy { + any::() + .prop_flat_map(|t| any_partial_value_of_type(t.clone()).prop_map(move |v| (t.clone(), v))) +} + +proptest! 
{ + #[test] + fn partial_value_type((tst, pv) in any_typed_partial_value()) { + prop_assert!(tst.type_check(&pv)) + } + + // todo: ValidHandle is valid + // todo: ValidHandle eq is an equivalence relation + + // todo: PartialValue PartialOrd is transitive + // todo: PartialValue eq is an equivalence relation + #[test] + fn partial_value_valid(pv in any_partial_value()) { + pv.assert_invariants(); + } + + #[test] + fn bounded_lattice(v in any_partial_value()) { + prop_assert!(v <= PartialValue::top()); + prop_assert!(v >= PartialValue::bottom()); + } + + #[test] + fn meet_join_self_noop(v1 in any_partial_value()) { + let mut subject = v1.clone(); + + assert_eq!(v1.clone(), v1.clone().join(v1.clone())); + assert!(!subject.join_mut(v1.clone())); + assert_eq!(subject, v1); + + assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); + assert!(!subject.meet_mut(v1.clone())); + assert_eq!(subject, v1); + } + + #[test] + fn lattice([v1,v2] in any_partial_values()) { + let meet = v1.clone().meet(v2.clone()); + prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); + prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); + + let join = v1.clone().join(v2.clone()); + prop_assert!(join >= v1, "join not >=: {:#?}", &join); + prop_assert!(join >= v2, "join not >=: {:#?}", &join); + } +} diff --git a/hugr-core/src/partial_value/value_handle.rs b/hugr-core/src/partial_value/value_handle.rs new file mode 100644 index 000000000..dfb019872 --- /dev/null +++ b/hugr-core/src/partial_value/value_handle.rs @@ -0,0 +1,245 @@ +use std::any::Any; +use std::hash::{DefaultHasher, Hash, Hasher}; +use std::ops::Deref; +use std::sync::Arc; + +use downcast_rs::Downcast; +use itertools::Either; + +use crate::ops::Value; +use crate::std_extensions::arithmetic::int_types::ConstInt; +use crate::Node; + +pub trait ValueName: std::fmt::Debug + Downcast + Any { + fn hash(&self) -> u64; + fn eq(&self, other: &dyn ValueName) -> bool; +} + +fn hash_hash(x: &impl Hash) -> u64 { + let mut hasher = 
DefaultHasher::new(); + x.hash(&mut hasher); + hasher.finish() +} + +fn value_name_eq(x: &T, other: &dyn ValueName) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + x == other + } else { + false + } +} + +impl ValueName for String { + fn hash(&self) -> u64 { + hash_hash(self) + } + + fn eq(&self, other: &dyn ValueName) -> bool { + value_name_eq(self, other) + } +} + +impl ValueName for ConstInt { + fn hash(&self) -> u64 { + hash_hash(self) + } + + fn eq(&self, other: &dyn ValueName) -> bool { + value_name_eq(self, other) + } +} + +#[derive(Clone, Debug)] +pub struct ValueKey(Vec, Either>); + +impl PartialEq for ValueKey { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + && match (&self.1, &other.1) { + (Either::Left(ref n1), Either::Left(ref n2)) => n1 == n2, + (Either::Right(ref v1), Either::Right(ref v2)) => v1.eq(v2.as_ref()), + _ => false, + } + } +} + +impl Eq for ValueKey {} + +impl Hash for ValueKey { + fn hash(&self, state: &mut H) { + self.0.hash(state); + match &self.1 { + Either::Left(n) => (0, n).hash(state), + Either::Right(v) => (1, v.hash()).hash(state), + } + } +} + +impl From for ValueKey { + fn from(n: Node) -> Self { + Self(vec![], Either::Left(n)) + } +} + +impl ValueKey { + pub fn new(k: impl ValueName) -> Self { + Self(vec![], Either::Right(Arc::new(k))) + } + + pub fn index(self, i: usize) -> Self { + let mut is = self.0; + is.push(i); + Self(is, self.1) + } +} + +#[derive(Clone, Debug)] +pub struct ValueHandle(ValueKey, Arc); + +impl ValueHandle { + pub fn new(key: ValueKey, value: Arc) -> Self { + Self(key, value) + } + + pub fn value(&self) -> &Value { + self.1.as_ref() + } + + pub fn is_compound(&self) -> bool { + match self.value() { + Value::Sum { .. } | Value::Tuple { .. } => true, + _ => false, + } + } + + pub fn num_fields(&self) -> usize { + assert!( + self.is_compound(), + "ValueHandle::num_fields called on non-Sum, non-Tuple value: {:#?}", + self + ); + match self.value() { + Value::Sum { values, .. 
} => values.len(), + Value::Tuple { vs } => vs.len(), + _ => unreachable!(), + } + } + + pub fn tag(&self) -> usize { + assert!( + self.is_compound(), + "ValueHandle::tag called on non-Sum, non-Tuple value: {:#?}", + self + ); + match self.value() { + Value::Sum { tag, .. } => *tag, + Value::Tuple { .. } => 0, + _ => unreachable!(), + } + } + + pub fn index(self: &ValueHandle, i: usize) -> ValueHandle { + assert!( + i < self.num_fields(), + "ValueHandle::index called with out-of-bounds index {}: {:#?}", + i, + &self + ); + let vs = match self.value() { + Value::Sum { values, .. } => values, + Value::Tuple { vs, .. } => vs, + _ => unreachable!(), + }; + let v = vs[i].clone().into(); + Self(self.0.clone().index(i), v) + } +} + +impl PartialEq for ValueHandle { + fn eq(&self, other: &Self) -> bool { + // If the keys are equal, we return true since the values must have the + // same provenance, and so be equal. If the keys are different but the + // values are equal, we could return true if we didn't impl Eq, but + // since we do impl Eq, the Hash contract prohibits us from having equal + // values with different hashes. + let r = self.0 == other.0; + if r { + debug_assert_eq!(self.get_type(), other.get_type()); + } + r + } +} + +impl Eq for ValueHandle {} + +impl Hash for ValueHandle { + fn hash(&self, state: &mut I) { + self.0.hash(state); + } +} + +/// TODO this is perhaps dodgy +/// we do not hash or compare the value, just the key +/// this means two handles with different keys, but with the same value, will +/// not compare equal. 
+impl Deref for ValueHandle { + type Target = Value; + + fn deref(&self) -> &Self::Target { + self.value() + } +} + +#[cfg(test)] +mod test { + use crate::{ops::constant::CustomConst as _, types::SumType}; + + use super::*; + + #[test] + fn value_key_eq() { + let k1 = ValueKey::new("foo".to_string()); + let k2 = ValueKey::new("foo".to_string()); + let k3 = ValueKey::new("bar".to_string()); + + assert_eq!(k1, k2); + assert_ne!(k1, k3); + + let k4: ValueKey = From::::from(portgraph::NodeIndex::new(1).into()); + let k5 = From::::from(portgraph::NodeIndex::new(1).into()); + let k6 = From::::from(portgraph::NodeIndex::new(2).into()); + + assert_eq!(&k4, &k5); + assert_ne!(&k4, &k6); + + let k7 = k5.clone().index(3); + let k4 = k4.index(3); + + assert_eq!(&k4, &k7); + + let k5 = k5.index(2); + + assert_ne!(&k5, &k7); + } + + #[test] + fn value_handle_eq() { + let k_i = ConstInt::new_u(4, 2).unwrap(); + let subject_val = Arc::new( + Value::sum( + 0, + [k_i.clone().into()], + SumType::new([vec![k_i.get_type()], vec![]]), + ) + .unwrap(), + ); + + let k1 = ValueKey::new("foo".to_string()); + let v1 = ValueHandle::new(k1.clone(), subject_val.clone()); + let v2 = ValueHandle::new(k1.clone(), Value::extension(k_i).into()); + + // we do not compare the value, just the key + assert_ne!(v1.index(0), v2); + assert_eq!(v1.index(0).value(), v2.value()); + } +} diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index 5ff2abea3..d79a4b84f 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -63,7 +63,7 @@ pub const LOG_WIDTH_TYPE_PARAM: TypeParam = TypeParam::bounded_nat({ /// Get the log width of the specified type argument or error if the argument /// is invalid. 
-pub(super) fn get_log_width(arg: &TypeArg) -> Result { +pub(crate) fn get_log_width(arg: &TypeArg) -> Result { match arg { TypeArg::BoundedNat { n } if is_valid_log_width(*n as u8) => Ok(*n as u8), _ => Err(TypeArgError::TypeMismatch { @@ -80,7 +80,7 @@ const fn type_arg(log_width: u8) -> TypeArg { } /// An integer (either signed or unsigned) -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Debug, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize)] pub struct ConstInt { log_width: u8, // We always use a u64 for the value. The interpretation is: diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index cb8cd0706..dd1cd9b29 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -118,7 +118,7 @@ pub(crate) fn least_upper_bound(mut tags: impl Iterator) -> Ty .into_inner() } -#[derive(Clone, PartialEq, Debug, Eq, Serialize, Deserialize)] +#[derive(Clone, PartialEq, Debug, Eq, Hash, Serialize, Deserialize)] #[serde(tag = "s")] #[non_exhaustive] /// Representation of a Sum type. @@ -196,7 +196,7 @@ impl From for Type { } } -#[derive(Clone, PartialEq, Debug, Eq, derive_more::Display)] +#[derive(Clone, PartialEq, Debug, Eq, Hash, derive_more::Display)] #[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))] /// Core types pub enum TypeEnum { @@ -249,7 +249,7 @@ impl TypeEnum { } #[derive( - Clone, PartialEq, Debug, Eq, derive_more::Display, serde::Serialize, serde::Deserialize, + Clone, PartialEq, Debug, Eq, Hash, derive_more::Display, serde::Serialize, serde::Deserialize, )] #[display(fmt = "{}", "_0")] #[serde(into = "serialize::SerSimpleType", from = "serialize::SerSimpleType")] diff --git a/hugr-core/src/types/custom.rs b/hugr-core/src/types/custom.rs index 93d2e506a..c3c4c3f7e 100644 --- a/hugr-core/src/types/custom.rs +++ b/hugr-core/src/types/custom.rs @@ -12,7 +12,7 @@ use super::{ use super::{Type, TypeName}; /// An opaque type element. 
Contains the unique identifier of its definition. -#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)] +#[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] pub struct CustomType { extension: ExtensionId, /// Unique identifier of the opaque type. diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index 1cf731eda..542b3fc0b 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -14,7 +14,7 @@ use crate::{Direction, IncomingPort, OutgoingPort, Port}; #[cfg(test)] use {crate::proptest::RecursionDepth, ::proptest::prelude::*, proptest_derive::Arbitrary}; -#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] #[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))] /// Describes the edges required to/from a node, and thus, also the type of a [Graph]. 
/// This includes both the concept of "signature" in the spec, diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index e7019fe7c..572ac1220 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -19,7 +19,7 @@ use super::{check_typevar_decl, CustomType, Substitution, Type, TypeBound}; /// The upper non-inclusive bound of a [`TypeParam::BoundedNat`] // A None inner value implies the maximum bound: u64::MAX + 1 (all u64 values valid) #[derive( - Clone, Debug, PartialEq, Eq, derive_more::Display, serde::Deserialize, serde::Serialize, + Clone, Debug, PartialEq, Eq, Hash, derive_more::Display, serde::Deserialize, serde::Serialize, )] #[display(fmt = "{}", "_0.map(|i|i.to_string()).unwrap_or(\"-\".to_string())")] #[cfg_attr(test, derive(Arbitrary))] @@ -52,7 +52,7 @@ impl UpperBound { /// [PolyFuncType]: super::PolyFuncType /// [OpDef]: crate::extension::OpDef #[derive( - Clone, Debug, PartialEq, Eq, derive_more::Display, serde::Deserialize, serde::Serialize, + Clone, Debug, PartialEq, Eq, Hash, derive_more::Display, serde::Deserialize, serde::Serialize, )] #[non_exhaustive] #[serde(tag = "tp")] @@ -142,7 +142,7 @@ impl From for TypeParam { } /// A statically-known argument value to an operation. -#[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)] #[non_exhaustive] #[serde(tag = "tya")] pub enum TypeArg { @@ -214,7 +214,7 @@ impl From for TypeArg { } /// Variable in a TypeArg, that is not a [TypeArg::Type] or [TypeArg::Extensions], -#[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)] pub struct TypeArgVariable { idx: usize, cached_decl: TypeParam, @@ -352,7 +352,7 @@ impl TypeArgVariable { /// A serialized representation of a value of a [CustomType] /// restricted to equatable types. 
-#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] pub struct CustomTypeArg { /// The type of the constant. /// (Exact matches only - the constant is exactly this type.) diff --git a/hugr-core/src/types/type_row.rs b/hugr-core/src/types/type_row.rs index 5be3ddc2f..17182f788 100644 --- a/hugr-core/src/types/type_row.rs +++ b/hugr-core/src/types/type_row.rs @@ -16,7 +16,7 @@ use delegate::delegate; use itertools::Itertools; /// List of types, used for function signatures. -#[derive(Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)] +#[derive(Clone, PartialEq, Eq, Debug, Hash, serde::Serialize, serde::Deserialize)] #[non_exhaustive] #[serde(transparent)] pub struct TypeRow { diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 240468f4c..92967f9c9 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -18,9 +18,14 @@ itertools = { workspace = true } lazy_static = { workspace = true } paste = { workspace = true } thiserror = { workspace = true } +ascent = "0.6.0" +either = "*" +delegate = "*" [features] extension_inference = ["hugr-core/extension_inference"] [dev-dependencies] rstest = { workspace = true } +proptest = { workspace = true } +proptest-derive = { workspace = true } diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs new file mode 100644 index 000000000..f01a78c6e --- /dev/null +++ b/hugr-passes/src/const_fold2.rs @@ -0,0 +1 @@ +mod datalog; diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs new file mode 100644 index 000000000..98e80d0a9 --- /dev/null +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -0,0 +1,247 @@ +use ascent::lattice::{ord_lattice::OrdLattice, BoundedLattice, Dual, Lattice}; +use delegate::delegate; +use itertools::{zip_eq, Itertools}; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::sync::{Arc, 
Mutex}; + +use either::Either; +use hugr_core::ops::{OpTag, OpTrait, Value}; +use hugr_core::partial_value::{PartialValue, ValueHandle}; +use hugr_core::types::{EdgeKind, FunctionType, SumType, Type, TypeEnum, TypeRow}; +use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; + +mod context; +mod utils; + +pub use context::{ArcDataflowContext, DFContext, ValueCache}; +pub use utils::{TailLoopTermination, ValueRow, IO, PV}; + +ascent::ascent! { + struct AscentProgram; + relation context(C); + relation out_wire_value_proto(Node, OutgoingPort, PV); + + relation node(C, Node); + relation in_wire(C, Node, IncomingPort); + relation out_wire(C, Node, OutgoingPort); + relation parent_of_node(C, Node, Node); + relation io_node(C, Node, Node, IO); + lattice out_wire_value(C, Node, OutgoingPort, PV); + lattice node_in_value_row(C, Node, ValueRow); + lattice in_wire_value(C, Node, IncomingPort, PV); + + node(c, n) <-- context(c), for n in c.hugr().nodes(); + + in_wire(c, n,p) <-- node(c, n), for p in utils::value_inputs(c, *n); + + out_wire(c, n,p) <-- node(c, n), for p in utils::value_outputs(c, *n); + + parent_of_node(c, parent, child) <-- + node(c, child), if let Some(parent) = c.hugr().get_parent(*child); + + io_node(c, parent, child, io) <-- node(c, parent), + if let Some([i,o]) = c.hugr().get_io(*parent), + for (child,io) in [(i,IO::Input),(o,IO::Output)]; + // We support prepopulating out_wire_value via out_wire_value_proto. + // + // out wires that do not have prepopulation values are initialised to bottom. 
+ out_wire_value(c, n,p, PV::bottom()) <-- out_wire(c, n,p); + out_wire_value(c, n, p, v) <-- out_wire(c,n,p) , out_wire_value_proto(n, p, v); + + in_wire_value(c, n, ip, v) <-- in_wire(c, n, ip), + if let Some((m,op)) = c.hugr().single_linked_output(*n, *ip), + out_wire_value(c, m, op, v); + + + node_in_value_row(c, n, utils::bottom_row(c, *n)) <-- node(c, n); + node_in_value_row(c, n, utils::singleton_in_row(c, n, p, v.clone())) <-- in_wire_value(c, n, p, v); + + + // Per node-type rules + // TODO do all leaf ops with a rule + // define `fn propagate_leaf_op(Context, Node, ValueRow) -> ValueRow + + // LoadConstant + relation load_constant_node(C, Node); + load_constant_node(c, n) <-- node(c, n), if c.hugr().get_optype(*n).is_load_constant(); + + out_wire_value(c, n, 0.into(), utils::partial_value_from_load_constant(c, *n)) <-- + load_constant_node(c, n); + + + // MakeTuple + relation make_tuple_node(C, Node); + make_tuple_node(c, n) <-- node(c, n), if c.hugr().get_optype(*n).is_make_tuple(); + + out_wire_value(c, n, 0.into(), utils::partial_value_tuple_from_value_row(vs.clone())) <-- + make_tuple_node(c, n), node_in_value_row(c, n, vs); + + + // UnpackTuple + relation unpack_tuple_node(C, Node); + unpack_tuple_node(c,n) <-- node(c, n), if c.hugr().get_optype(*n).is_unpack_tuple(); + + out_wire_value(c, n, p, v.tuple_field_value(p.index())) <-- + unpack_tuple_node(c, n), + in_wire_value(c, n, IncomingPort::from(0), v), + out_wire(c, n, p); + + + // DFG + relation dfg_node(C, Node); + dfg_node(c,n) <-- node(c, n), if c.hugr().get_optype(*n).is_dfg(); + + out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- dfg_node(c,dfg), + io_node(c, dfg, i, IO::Input), in_wire_value(c, dfg, p, v); + + out_wire_value(c, dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(c,dfg), + io_node(c,dfg,o, IO::Output), in_wire_value(c, o, p, v); + + + // TailLoop + relation tail_loop_node(C, Node); + tail_loop_node(c,n) <-- node(c, n), if c.hugr().get_optype(*n).is_tail_loop(); + + 
// inputs of tail loop propagate to Input node of child region + out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- tail_loop_node(c, tl), + io_node(c,tl,i, IO::Input), in_wire_value(c, tl, p, v); + + // Output node of child region propagate to Input node of child region + out_wire_value(c, in_n, out_p, v) <-- tail_loop_node(c, tl_n), + io_node(c,tl_n,in_n, IO::Input), + io_node(c,tl_n,out_n, IO::Output), + node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node + if out_in_row[0].supports_tag(0), // if it is possible for tag to be 0 + if let Some(tailloop) = c.hugr().get_optype(*tl_n).as_tail_loop(), + let variant_len = tailloop.just_inputs.len(), + for (out_p, v) in out_in_row.iter(c, *out_n).flat_map( + |(input_p, v)| utils::outputs_for_variant(input_p, 0, variant_len, v) + ); + + // Output node of child region propagate to outputs of tail loop + out_wire_value(c, tl_n, out_p, v) <-- tail_loop_node(c, tl_n), + io_node(c,tl_n,out_n, IO::Output), + node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node + if out_in_row[0].supports_tag(1), // if it is possible for the tag to be 1 + if let Some(tailloop) = c.hugr().get_optype(*tl_n).as_tail_loop(), + let variant_len = tailloop.just_outputs.len(), + for (out_p, v) in out_in_row.iter(c, *out_n).flat_map( + |(input_p, v)| utils::outputs_for_variant(input_p, 1, variant_len, v) + ); + + lattice tail_loop_termination(C,Node,TailLoopTermination); + tail_loop_termination(c,tl_n,TailLoopTermination::bottom()) <-- + tail_loop_node(c,tl_n); + tail_loop_termination(c,tl_n,TailLoopTermination::from_control_value(v)) <-- + tail_loop_node(c,tl_n), + io_node(c,tl,out_n, IO::Output), + in_wire_value(c, out_n, IncomingPort::from(0), v); + + + // Conditional + relation conditional_node(C, Node); + relation case_node(C,Node,usize, Node); + + conditional_node (c,n)<-- node(c, n), if c.hugr().get_optype(*n).is_conditional(); + case_node(c,cond,i, case) <-- 
conditional_node(c,cond), + for (i, case) in c.hugr().children(*cond).enumerate(), + if c.hugr().get_optype(case).is_case(); + + // inputs of conditional propagate into case nodes + out_wire_value(c, i_node, i_p, v) <-- + case_node(c, cond, case_index, case), + io_node(c, case, i_node, IO::Input), + in_wire_value(c, cond, cond_in_p, cond_in_v), + if let Some(conditional) = c.hugr().get_optype(*cond).as_conditional(), + let variant_len = conditional.sum_rows[*case_index].len(), + for (i_p, v) in utils::outputs_for_variant(*cond_in_p, *case_index, variant_len, cond_in_v); + + // outputs of case nodes propagate to outputs of conditional + out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- + case_node(c, cond, _, case), + io_node(c, case, o, IO::Output), + in_wire_value(c, o, o_p, v); + + lattice case_reachable(C, Node, Node, bool); + case_reachable(c, cond, case, reachable) <-- case_node(c,cond,i,case), + in_wire_value(c, cond, IncomingPort::from(0), v), + let reachable = v.supports_tag(*i); + +} + +struct Machine<'a, H: HugrView> { + program: AscentProgram>, + cache: Arc>, +} + +impl<'a, H: HugrView> Machine<'a, H> { + pub fn new() -> Self { + Self { + program: Default::default(), + cache: ValueCache::new(), + } + } + + pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator) { + self.program.out_wire_value_proto.extend(wires.into_iter().map(|(w,v)| (w.node(), w.source(), v.into()))); + } + + pub fn run_hugr(&mut self, hugr: &'a H) -> ArcDataflowContext<'a, H> { + let context = ArcDataflowContext::new(hugr, self.cache.clone()); + self.program.context.push((context.clone(),)); + self.program.run(); + context + } + + pub fn read_out_wire_partial_value( + &self, + context: &ArcDataflowContext<'a, H>, + w: Wire, + ) -> Option { + self.program.out_wire_value.iter().find_map(|(c, n, p, v)| { + (c == context && &w.node() == n && &w.source() == p).then(|| v.clone().into()) + }) + } + + pub fn read_out_wire_value( + &self, + context: 
&ArcDataflowContext<'a, H>, + w: Wire, + ) -> Option { + // dbg!(&w); + let pv = self.read_out_wire_partial_value(context, w)?; + // dbg!(&pv); + let (_, typ) = context + .hugr() + .out_value_types(w.node()) + .find(|(p, _)| *p == w.source()) + .unwrap(); + pv.try_into_value(&typ).ok() + } + + pub fn tail_loop_terminates( + &self, + context: &ArcDataflowContext<'a, H>, + node: Node, + ) -> TailLoopTermination { + assert!(context.get_optype(node).is_tail_loop()); + self.program + .tail_loop_termination + .iter() + .find_map(|(c, n, v)| (c == context && n == &node).then_some(*v)) + .unwrap() + } + + pub fn case_reachable(&self, + context: &ArcDataflowContext<'a, H>, + case: Node,) -> bool { + assert!(context.get_optype(case).is_case()); + let cond = context.hugr().get_parent(case).unwrap(); + assert!(context.get_optype(cond).is_conditional()); + self.program.case_reachable.iter().find_map(|(c,cond2,case2,i)| (c == context && &cond == cond2 && &case == case2).then_some(*i)).unwrap() + } +} + +#[cfg(test)] +mod test; diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs new file mode 100644 index 000000000..1661fb082 --- /dev/null +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -0,0 +1,159 @@ +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::ops::Deref; +use std::sync::atomic::AtomicUsize; +use std::sync::{Arc, Mutex}; + +use hugr_core::ops::Value; +use hugr_core::partial_value::{ValueHandle, ValueKey}; +use hugr_core::{HugrView, Node}; + +#[derive(Clone)] +pub struct ValueCache(HashMap>); + +impl ValueCache { + pub fn new() -> Arc> { + Arc::new(Mutex::new(Self::new_bare())) + } + + fn new_bare() -> Self { + Self(HashMap::new()) + } + + fn get(&mut self, key: ValueKey, value: &Value) -> ValueHandle { + let v = self + .0 + .entry(key.clone()) + .or_insert_with(|| value.clone().into()) + .clone(); + ValueHandle::new(key, v) + } +} + +static mut CONTEXT_ID: AtomicUsize = 
AtomicUsize::new(0); + +fn next_context_id() -> usize { + unsafe { CONTEXT_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst) } +} + +pub struct DataflowContext<'a, H> { + id: usize, + hugr: &'a H, + cache: Arc>, +} + +impl<'a, H> DataflowContext<'a, H> { + fn new(hugr: &'a H, cache: Arc>) -> Self { + Self { + id: next_context_id(), + hugr, + cache, + } + } + + pub fn get_value_handle(&self, key: impl Into, value: &Value) -> ValueHandle { + let mut guard = self.cache.lock().unwrap(); + guard.get(key.into(), value) + } + + pub fn hugr(&self) -> &'a H { + self.hugr + } + + pub fn id(&self) -> usize { + self.id + } +} + +impl std::fmt::Debug for DataflowContext<'_, H> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DataflowContext({})", self.id) + } +} + +impl Hash for DataflowContext<'_, H> { + fn hash(&self, state: &mut I) { + self.id.hash(state); + } +} + +impl PartialEq for DataflowContext<'_, H> { + fn eq(&self, other: &usize) -> bool { + &self.id == other + } +} + +impl PartialEq for DataflowContext<'_, H> { + fn eq(&self, other: &Self) -> bool { + self == &other.id + } +} + +impl Eq for DataflowContext<'_, H> {} + +impl PartialOrd for DataflowContext<'_, H> { + fn partial_cmp(&self, _other: &Self) -> Option { + self.id.partial_cmp(&_other.id) + } +} + +impl<'a, H> Deref for DataflowContext<'a, H> { + type Target = H; + + fn deref(&self) -> &Self::Target { + self.hugr + } +} + +pub struct ArcDataflowContext<'a, H>(Arc>); + +impl<'a, H> ArcDataflowContext<'a, H> { + pub fn new(h: &'a H, cache: Arc>) -> Self { + Self(Arc::new(DataflowContext::new(h, cache))) + } +} + +impl<'a, H> Clone for ArcDataflowContext<'a, H> { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl<'a, H> Hash for ArcDataflowContext<'a, H> { + fn hash(&self, state: &mut HA) { + self.0.hash(state); + } +} + +impl<'a, H> PartialEq for ArcDataflowContext<'a, H> { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl<'a, H> 
Eq for ArcDataflowContext<'a, H> {} + +impl<'a, H> Deref for ArcDataflowContext<'a, H> { + type Target = DataflowContext<'a, H>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +pub trait DFContext: Clone + Eq + Hash { + type H: HugrView; + fn hugr(&self) -> &Self::H; + fn node_value_handle(&self, const_node: Node, value: &Value) -> ValueHandle; +} + +impl<'a, H: HugrView> DFContext for ArcDataflowContext<'a, H> { + type H = H; + fn hugr(&self) -> &Self::H { + self.0.hugr + } + + fn node_value_handle(&self, const_node: Node, value: &Value) -> ValueHandle { + self.0.get_value_handle(const_node, value) + } +} diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs new file mode 100644 index 000000000..e35ee0a47 --- /dev/null +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -0,0 +1,215 @@ +use hugr_core::{ + builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{prelude::BOOL_T, ExtensionSet, EMPTY_REG}, ops::{handle::NodeHandle, OpTrait, UnpackTuple, Value}, partial_value::PartialSum, type_row, types::{FunctionType, SumType}, Extension +}; + +use hugr_core::partial_value::PartialValue; + +use super::*; + +#[test] +fn test_make_tuple() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let v1 = builder.add_load_value(Value::false_val()); + let v2 = builder.add_load_value(Value::true_val()); + let v3 = builder.make_tuple([v1, v2]).unwrap(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + let c = machine.run_hugr(&hugr); + + let x = machine.read_out_wire_value(&c, v3).unwrap(); + assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); +} + +#[test] +fn test_unpack_tuple() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let v1 = builder.add_load_value(Value::false_val()); + let v2 = builder.add_load_value(Value::true_val()); + let v3 = 
builder.make_tuple([v1, v2]).unwrap(); + let [o1, o2] = builder + .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T, BOOL_T]), [v3]) + .unwrap() + .outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + let c = machine.run_hugr(&hugr); + + let o1_r = machine.read_out_wire_value(&c, o1).unwrap(); + assert_eq!(o1_r, Value::false_val()); + let o2_r = machine.read_out_wire_value(&c, o2).unwrap(); + assert_eq!(o2_r, Value::true_val()); +} + +#[test] +fn test_unpack_const() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let v1 = builder.add_load_value(Value::tuple([Value::true_val()])); + let [o] = builder + .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T]), [v1]) + .unwrap() + .outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + let c = machine.run_hugr(&hugr); + + let o_r = machine.read_out_wire_value(&c, o).unwrap(); + assert_eq!(o_r, Value::true_val()); +} + +#[test] +fn test_tail_loop_never_iterates() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let r_v = Value::unit_sum(3, 6).unwrap(); + let r_w = builder.add_load_value( + Value::sum( + 1, + [r_v.clone()], + SumType::new([type_row![], r_v.get_type().into()]), + ) + .unwrap(), + ); + let tlb = builder + .tail_loop_builder([], [], vec![r_v.get_type()].into()) + .unwrap(); + let tail_loop = tlb.finish_with_outputs(r_w, []).unwrap(); + let [tl_o] = tail_loop.outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + let c = machine.run_hugr(&hugr); + // dbg!(&machine.tail_loop_io_node); + // dbg!(&machine.out_wire_value); + + let o_r = machine.read_out_wire_value(&c, tl_o).unwrap(); + assert_eq!(o_r, r_v); + assert_eq!( + TailLoopTermination::ExactlyZeroContinues, + machine.tail_loop_terminates(&c, tail_loop.node()) + ) +} + +#[test] +fn test_tail_loop_always_iterates() { + let mut builder 
= DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let r_w = builder + .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); + let true_w = builder.add_load_value(Value::true_val()); + + let tlb = builder + .tail_loop_builder([], [(BOOL_T,true_w)], vec![BOOL_T].into()) + .unwrap(); + + // r_w has tag 0, so we always continue; + // we put true in our "other_output", but we should not propagate this to + // output because r_w never supports 1. + let tail_loop = tlb.finish_with_outputs(r_w, [true_w]).unwrap(); + + let [tl_o1, tl_o2] = tail_loop.outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + let c = machine.run_hugr(&hugr); + + let o_r1 = machine.read_out_wire_partial_value(&c, tl_o1).unwrap(); + assert_eq!(o_r1, PartialValue::bottom()); + let o_r2 = machine.read_out_wire_partial_value(&c, tl_o2).unwrap(); + assert_eq!(o_r2, PartialValue::bottom()); + assert_eq!( + TailLoopTermination::bottom(), + machine.tail_loop_terminates(&c, tail_loop.node()) + ) +} + +#[test] +fn test_tail_loop_iterates_twice() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + // let var_type = Type::new_sum([type_row![BOOL_T,BOOL_T], type_row![BOOL_T,BOOL_T]]); + + let true_w = builder.add_load_value(Value::true_val()); + let false_w = builder.add_load_value(Value::false_val()); + + // let r_w = builder + // .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); + let tlb = builder + .tail_loop_builder([], [(BOOL_T, false_w), (BOOL_T, true_w)], vec![].into()) + .unwrap(); + assert_eq!( + tlb.loop_signature().unwrap().dataflow_signature().unwrap(), + FunctionType::new_endo(type_row![BOOL_T, BOOL_T]) + ); + let [in_w1, in_w2] = tlb.input_wires_arr(); + let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap(); + + // let optype = builder.hugr().get_optype(tail_loop.node()); + // for p in 
builder.hugr().node_outputs(tail_loop.node()) { + // use hugr_core::ops::OpType; + // println!("{:?}, {:?}", p, optype.port_kind(p)); + + // } + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + // TODO once we can do conditionals put these wires inside `just_outputs` and + // we should be able to propagate their values + let [o_w1, o_w2, _] = tail_loop.outputs_arr(); + + let mut machine = Machine::new(); + let c = machine.run_hugr(&hugr); + // dbg!(&machine.tail_loop_io_node); + // dbg!(&machine.out_wire_value); + + // TODO these should be the propagated values; for now they will be join(true,false) + let o_r1 = machine.read_out_wire_partial_value(&c, o_w1).unwrap(); + // assert_eq!(o_r1, PartialValue::top()); + let o_r2 = machine.read_out_wire_partial_value(&c, o_w2).unwrap(); + // assert_eq!(o_r2, Value::true_val()); + assert_eq!( + TailLoopTermination::Top, + machine.tail_loop_terminates(&c, tail_loop.node()) + ) +} + +#[test] +fn conditional() { + let variants = vec![type_row![], type_row![], type_row![BOOL_T]]; + let cond_t = Type::new_sum(variants.clone()); + let mut builder = DFGBuilder::new(FunctionType::new(Into::::into(cond_t),type_row![])).unwrap(); + let [arg_w] = builder.input_wires_arr(); + + let true_w = builder.add_load_value(Value::true_val()); + let false_w = builder.add_load_value(Value::false_val()); + + let mut cond_builder = builder.conditional_builder((variants, arg_w), [(BOOL_T,true_w)], type_row!(BOOL_T,BOOL_T), ExtensionSet::default()).unwrap(); + // will be unreachable + let case1_b = cond_builder.case_builder(0).unwrap(); + let case1 = case1_b.finish_with_outputs([false_w,false_w]).unwrap(); + + let case2_b = cond_builder.case_builder(1).unwrap(); + let [c2a] = case2_b.input_wires_arr(); + let case2 = case2_b.finish_with_outputs([false_w,c2a]).unwrap(); + + let case3_b = cond_builder.case_builder(2).unwrap(); + let [c3_1,c3_2] = case3_b.input_wires_arr(); + let case3 = case3_b.finish_with_outputs([c3_1,false_w]).unwrap(); + + let 
cond = cond_builder.finish_sub_container().unwrap(); + + let [cond_o1,cond_o2] = cond.outputs_arr(); + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + let arg_pv = PartialValue::variant(1, []).join(PartialValue::variant(2,[PartialValue::variant(0,[])])); + machine.propolutate_out_wires([(arg_w, arg_pv)]); + let c = machine.run_hugr(&hugr); + + let cond_r1 = machine.read_out_wire_value(&c, cond_o1).unwrap(); + assert_eq!(cond_r1, Value::false_val()); + assert!(machine.read_out_wire_value(&c, cond_o2).is_none()); + + assert!(!machine.case_reachable(&c, case1.node())); + assert!(machine.case_reachable(&c, case2.node())); + assert!(machine.case_reachable(&c, case3.node())); +} diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs new file mode 100644 index 000000000..00162e73b --- /dev/null +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -0,0 +1,407 @@ +// proptest-derive generates many of these warnings. +// https://github.com/rust-lang/rust/issues/120363 +// https://github.com/proptest-rs/proptest/issues/447 +#![cfg_attr(test, allow(non_local_definitions))] + +use std::{cmp::Ordering, ops::Index}; + +use ascent::lattice::{BoundedLattice, Lattice}; +use either::Either; +use hugr_core::{ + ops::OpTrait as _, + partial_value::{PartialValue, ValueHandle}, + types::{EdgeKind, TypeRow}, + HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, +}; +use itertools::zip_eq; + +#[cfg(test)] +use proptest_derive::Arbitrary; + +use super::context::DFContext; + +#[derive(PartialEq, Eq, Hash, PartialOrd, Clone, Debug)] +pub struct PV(PartialValue); + +impl From for PV { + fn from(inner: PartialValue) -> Self { + Self(inner) + } +} + +impl PV { + pub fn tuple_field_value(&self, idx: usize) -> Self { + self.variant_field_value(0, idx) + } + + /// TODO the arguments here are not pretty, two usizes, better not mix them + /// up!!! 
+ pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { + self.0.variant_field_value(variant, idx).into() + } + + pub fn supports_tag(&self, tag: usize) -> bool { + self.0.supports_tag(tag) + } +} + +impl From for PartialValue { + fn from(value: PV) -> Self { + value.0 + } +} + +impl From for PV { + fn from(inner: ValueHandle) -> Self { + Self(inner.into()) + } +} + +impl Lattice for PV { + fn meet(self, other: Self) -> Self { + self.0.meet(other.0).into() + } + + fn meet_mut(&mut self, other: Self) -> bool { + self.0.meet_mut(other.0) + } + + fn join(self, other: Self) -> Self { + self.0.join(other.0).into() + } + + fn join_mut(&mut self, other: Self) -> bool { + self.0.join_mut(other.0) + } +} + +impl BoundedLattice for PV { + fn bottom() -> Self { + PartialValue::bottom().into() + } + + fn top() -> Self { + PartialValue::top().into() + } +} + +#[derive(PartialEq, Clone, Eq, Hash, PartialOrd)] +pub struct ValueRow(Vec); + +impl ValueRow { + fn new(len: usize) -> Self { + Self(vec![PV::bottom(); len]) + } + + fn singleton(len: usize, idx: usize, v: PV) -> Self { + assert!(idx < len); + let mut r = Self::new(len); + r.0[idx] = v; + r + } + + fn singleton_from_row(r: &TypeRow, idx: usize, v: PV) -> Self { + Self::singleton(r.len(), idx, v) + } + + fn bottom_from_row(r: &TypeRow) -> Self { + Self::new(r.len()) + } + + pub fn iter<'b>( + &'b self, + context: &'b impl DFContext, + n: Node, + ) -> impl Iterator + 'b { + zip_eq(value_inputs(context, n), self.0.iter()) + } + + // fn initialised(&self) -> bool { + // self.0.iter().all(|x| x != &PV::top()) + // } +} + +impl Lattice for ValueRow { + fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn join_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= 
v1.join_mut(v2); + } + changed + } + + fn meet_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.meet_mut(v2); + } + changed + } +} + +impl IntoIterator for ValueRow { + type Item = PV; + + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl Index for ValueRow where + Vec: Index +{ + type Output = as Index>::Output; + + fn index(&self, index: Idx) -> &Self::Output { + self.0.index(index) + } +} + +pub(super) fn bottom_row(context: &impl DFContext, n: Node) -> ValueRow { + if let Some(sig) = context.hugr().signature(n) { + ValueRow::new(sig.input_count()) + } else { + ValueRow::new(0) + } +} + +pub(super) fn singleton_in_row( + context: &impl DFContext, + n: &Node, + ip: &IncomingPort, + v: PV, +) -> ValueRow { + let Some(sig) = context.hugr().signature(*n) else { + panic!("dougrulz"); + }; + if sig.input_count() <= ip.index() { + panic!( + "bad port index: {} >= {}: {}", + ip.index(), + sig.input_count(), + context.hugr().get_optype(*n).description() + ); + } + ValueRow::singleton_from_row(&context.hugr().signature(*n).unwrap().input, ip.index(), v) +} + +pub(super) fn partial_value_from_load_constant(context: &impl DFContext, node: Node) -> PV { + let load_op = context.hugr().get_optype(node).as_load_constant().unwrap(); + let const_node = context + .hugr() + .single_linked_output(node, load_op.constant_port()) + .unwrap() + .0; + let const_op = context.hugr().get_optype(const_node).as_const().unwrap(); + context + .node_value_handle(const_node, const_op.value()) + .into() +} + +pub(super) fn partial_value_tuple_from_value_row(r: ValueRow) -> PV { + PartialValue::variant(0, r.into_iter().map(|x| x.0)).into() +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum IO { + Input, + Output, +} + +pub(super) fn value_inputs( + context: &impl 
DFContext, + n: Node, +) -> impl Iterator + '_ { + context.hugr().in_value_types(n).map(|x| x.0) +} + +pub(super) fn value_outputs( + context: &impl DFContext, + n: Node, +) -> impl Iterator + '_ { + context.hugr().out_value_types(n).map(|x| x.0) +} + +// We have several cases where sum types propagate to different places depending +// on their variant tag: +// - From the input of a conditional to the inputs of its case nodes +// - From the input of the output node of a tail loop to the output of the input node of the tail loop +// - From the input of the output node of a tail loop to the output of the tail loop node +// - From the input of the output node of a dataflow block to the output of the input node of a dataflow block +// - From the input of the output node of a dataflow block to the output of the cfg +// +// For a value `v` on an incoming port `output_p`, compute the (out port, value) +// pairs that should be propagated for a given variant tag. We must also supply +// the length of this variant because it cannot always be deduced from the other +// inputs. +// +// If `v` does not support `variant_tag`, then all propagated values will be bottom. +// +// If `output_p.index()` is 0 then the result is the contents of the variant. +// Otherwise, it is the single "other_output". 
+// +// TODO doctests +pub(super) fn outputs_for_variant<'a>( + output_p: IncomingPort, + variant_tag: usize, + variant_len: usize, + v: &'a PV, +) -> impl Iterator + 'a { + if output_p.index() == 0 { + Either::Left( + (0..variant_len) + .map(move |i| (i.into(), v.variant_field_value(variant_tag, i))), + ) + } else { + let v = if v.supports_tag(variant_tag) { + v.clone() + } else { + PV::bottom() + }; + Either::Right(std::iter::once(( + (variant_len + output_p.index() - 1).into(), + v, + ))) + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] +#[cfg_attr(test, derive(Arbitrary))] +pub enum TailLoopTermination { + Bottom, + ExactlyZeroContinues, + Top, +} + +impl TailLoopTermination { + pub fn from_control_value(v: &PV) -> Self { + let (may_continue, may_break) = (v.supports_tag(0),v.supports_tag(1)); + if may_break && !may_continue { + Self::ExactlyZeroContinues + } else if may_break && may_continue { + Self::top() + } else { + Self::bottom() + } + } +} + +impl PartialOrd for TailLoopTermination { + fn partial_cmp(&self, other: &Self) -> Option { + if self == other { + return Some(std::cmp::Ordering::Equal); + }; + match (self, other) { + (Self::Bottom,_) => Some(Ordering::Less), + (_,Self::Bottom) => Some(Ordering::Greater), + (Self::Top,_) => Some(Ordering::Greater), + (_,Self::Top) => Some(Ordering::Less), + _ => None + } + } +} + +impl Lattice for TailLoopTermination { + fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn meet_mut(&mut self, other: Self) -> bool { + // let new_self = &mut self; + match (*self).partial_cmp(&other) { + Some(Ordering::Greater) => { + *self = other; + true + } + Some(_) => false, + _ => { + *self = Self::Bottom; + true + } + } + } + + fn join_mut(&mut self, other: Self) -> bool { + match (*self).partial_cmp(&other) { + Some(Ordering::Less) => { + *self = other; + true + } + Some(_) => false, + _ => { + *self = 
Self::Top; + true + } + } + } +} + +impl BoundedLattice for TailLoopTermination { + fn bottom() -> Self { + Self::Bottom + + } + + fn top() -> Self { + Self::Top + } +} + +#[cfg(test)] +#[cfg_attr(test, allow(non_local_definitions))] +mod test { + use super::*; + use proptest::prelude::*; + + proptest! { + #[test] + fn bounded_lattice(v: TailLoopTermination) { + prop_assert!(v <= TailLoopTermination::top()); + prop_assert!(v >= TailLoopTermination::bottom()); + } + + #[test] + fn meet_join_self_noop(v1: TailLoopTermination) { + let mut subject = v1.clone(); + + assert_eq!(v1.clone(), v1.clone().join(v1.clone())); + assert!(!subject.join_mut(v1.clone())); + assert_eq!(subject, v1); + + assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); + assert!(!subject.meet_mut(v1.clone())); + assert_eq!(subject, v1); + } + + #[test] + fn lattice(v1: TailLoopTermination, v2: TailLoopTermination) { + let meet = v1.clone().meet(v2.clone()); + prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); + prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); + + let join = v1.clone().join(v2.clone()); + prop_assert!(join >= v1, "join not >=: {:#?}", &join); + prop_assert!(join >= v2, "join not >=: {:#?}", &join); + } + } +} diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 803196144..eab13b21e 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,6 +1,7 @@ //! Compilation passes acting on the HUGR program representation. pub mod const_fold; +pub mod const_fold2; mod half_node; pub mod merge_bbs; pub mod nest_cfgs;