diff --git a/hugr-core/src/ops/module.rs b/hugr-core/src/ops/module.rs index 40c189b68..fef5105fb 100644 --- a/hugr-core/src/ops/module.rs +++ b/hugr-core/src/ops/module.rs @@ -128,7 +128,7 @@ impl OpTrait for AliasDefn { } /// A type alias declaration. Resolved at link time. -#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] #[cfg_attr(test, derive(proptest_derive::Arbitrary))] pub struct AliasDecl { /// Alias name diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index cb8cd0706..dd1cd9b29 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -118,7 +118,7 @@ pub(crate) fn least_upper_bound(mut tags: impl Iterator) -> Ty .into_inner() } -#[derive(Clone, PartialEq, Debug, Eq, Serialize, Deserialize)] +#[derive(Clone, PartialEq, Debug, Eq, Hash, Serialize, Deserialize)] #[serde(tag = "s")] #[non_exhaustive] /// Representation of a Sum type. @@ -196,7 +196,7 @@ impl From for Type { } } -#[derive(Clone, PartialEq, Debug, Eq, derive_more::Display)] +#[derive(Clone, PartialEq, Debug, Eq, Hash, derive_more::Display)] #[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))] /// Core types pub enum TypeEnum { @@ -249,7 +249,7 @@ impl TypeEnum { } #[derive( - Clone, PartialEq, Debug, Eq, derive_more::Display, serde::Serialize, serde::Deserialize, + Clone, PartialEq, Debug, Eq, Hash, derive_more::Display, serde::Serialize, serde::Deserialize, )] #[display(fmt = "{}", "_0")] #[serde(into = "serialize::SerSimpleType", from = "serialize::SerSimpleType")] diff --git a/hugr-core/src/types/custom.rs b/hugr-core/src/types/custom.rs index 93d2e506a..c3c4c3f7e 100644 --- a/hugr-core/src/types/custom.rs +++ b/hugr-core/src/types/custom.rs @@ -12,7 +12,7 @@ use super::{ use super::{Type, TypeName}; /// An opaque type element. Contains the unique identifier of its definition. -#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)] +#[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] pub struct CustomType { extension: ExtensionId, /// Unique identifier of the opaque type. diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index 1cf731eda..542b3fc0b 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -14,7 +14,7 @@ use crate::{Direction, IncomingPort, OutgoingPort, Port}; #[cfg(test)] use {crate::proptest::RecursionDepth, ::proptest::prelude::*, proptest_derive::Arbitrary}; -#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] #[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))] /// Describes the edges required to/from a node, and thus, also the type of a [Graph]. /// This includes both the concept of "signature" in the spec, diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index e7019fe7c..572ac1220 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -19,7 +19,7 @@ use super::{check_typevar_decl, CustomType, Substitution, Type, TypeBound}; /// The upper non-inclusive bound of a [`TypeParam::BoundedNat`] // A None inner value implies the maximum bound: u64::MAX + 1 (all u64 values valid) #[derive( - Clone, Debug, PartialEq, Eq, derive_more::Display, serde::Deserialize, serde::Serialize, + Clone, Debug, PartialEq, Eq, Hash, derive_more::Display, serde::Deserialize, serde::Serialize, )] #[display(fmt = "{}", "_0.map(|i|i.to_string()).unwrap_or(\"-\".to_string())")] #[cfg_attr(test, derive(Arbitrary))] @@ -52,7 +52,7 @@ impl UpperBound { /// [PolyFuncType]: super::PolyFuncType /// [OpDef]: crate::extension::OpDef #[derive( - Clone, Debug, PartialEq, Eq, derive_more::Display, serde::Deserialize, serde::Serialize, + Clone, Debug, PartialEq, Eq, Hash, derive_more::Display, serde::Deserialize, serde::Serialize, )] #[non_exhaustive] #[serde(tag = "tp")] @@ -142,7 +142,7 @@ impl From for TypeParam { } /// A statically-known argument value to an operation. -#[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)] #[non_exhaustive] #[serde(tag = "tya")] pub enum TypeArg { @@ -214,7 +214,7 @@ impl From for TypeArg { } /// Variable in a TypeArg, that is not a [TypeArg::Type] or [TypeArg::Extensions], -#[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)] pub struct TypeArgVariable { idx: usize, cached_decl: TypeParam, @@ -352,7 +352,7 @@ impl TypeArgVariable { /// A serialized representation of a value of a [CustomType] /// restricted to equatable types. -#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] pub struct CustomTypeArg { /// The type of the constant. /// (Exact matches only - the constant is exactly this type.) diff --git a/hugr-core/src/types/type_row.rs b/hugr-core/src/types/type_row.rs index 5be3ddc2f..17182f788 100644 --- a/hugr-core/src/types/type_row.rs +++ b/hugr-core/src/types/type_row.rs @@ -16,7 +16,7 @@ use delegate::delegate; use itertools::Itertools; /// List of types, used for function signatures. -#[derive(Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)] +#[derive(Clone, PartialEq, Eq, Debug, Hash, serde::Serialize, serde::Deserialize)] #[non_exhaustive] #[serde(transparent)] pub struct TypeRow { diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 240468f4c..4dc128a78 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -18,6 +18,7 @@ itertools = { workspace = true } lazy_static = { workspace = true } paste = { workspace = true } thiserror = { workspace = true } +ascent = "0.6.0" [features] extension_inference = ["hugr-core/extension_inference"] diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs new file mode 100644 index 000000000..f01a78c6e --- /dev/null +++ b/hugr-passes/src/const_fold2.rs @@ -0,0 +1 @@ +mod datalog; diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs new file mode 100644 index 000000000..b2e717108 --- /dev/null +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -0,0 +1,237 @@ +use std::hash::{Hash, Hasher}; + +use ascent::{ascent_run, Lattice}; +use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; +use hugr_core::ops::{OpTag, OpTrait, Value}; +use hugr_core::types::{SumType, Type, TypeRow}; +use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; +use itertools::{zip_eq, Itertools}; +use std::collections::HashMap; + +#[derive(PartialEq, Clone, Eq)] +struct HashableHashMap(HashMap); + +impl Hash for HashableHashMap { + fn hash(&self, state: &mut H) { + self.0.keys().for_each(|k| k.hash(state)); + self.0.values().for_each(|v| v.hash(state)); + } +} + +#[derive(PartialEq, Clone, Eq, Hash)] +enum PartialValue { + Bottom(Type), + Value(Node, Type), + PartialSum(HashableHashMap>, SumType), + Top(Type), +} + +impl PartialValue { + fn get_type(&self) -> Type { + match self { + PartialValue::Bottom(t) => t.clone(), + PartialValue::Value(_, t) => t.clone(), + PartialValue::PartialSum(_, t) => t.clone().into(), + PartialValue::Top(t) => t.clone(), + } + } + + fn top_from_hugr(hugr: &impl HugrView, node: Node, port: OutgoingPort) -> Self { + Self::Top( + hugr.signature(node) + .unwrap() + .out_port_type(port) + .unwrap() + .clone(), + ) + } + + fn from_load_constant(hugr: &impl HugrView, node: Node) -> Self { + let load_op = hugr.get_optype(node).as_load_constant().unwrap(); + let const_node = hugr + .single_linked_output(node, load_op.constant_port()) + .unwrap() + .0; + let const_op = hugr.get_optype(const_node).as_const().unwrap(); + Self::Value(const_node, const_op.get_type()) + } + + fn tuple_from_value_row(r: &ValueRow) -> Self { + unimplemented!() + } + +} + +impl PartialOrd for PartialValue { + fn partial_cmp(&self, other: &Self) -> Option { + // TODO we can do better + (self == other).then_some(std::cmp::Ordering::Equal) + } +} + +impl Lattice for PartialValue { + fn meet(self, _other: Self) -> Self { + // should not be required + todo!() + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn join_mut(&mut self, other: Self) -> bool { + debug_assert_eq!(self.get_type(), other.get_type()); + match (self, other) { + (Self::Bottom(_), _) => false, + (s, rhs @ Self::Bottom(_)) => { + *s = rhs; + true + } + (_, Self::Top(_)) => false, + (s @ Self::Top(_), x) => { + *s = x; + true + } + (Self::Value(n1, t), Self::Value(n2, _)) if n1 == &n2 => false, + ( + Self::PartialSum(HashableHashMap(hm1), t), + Self::PartialSum(HashableHashMap(hm2), _), + ) => { + let mut changed = false; + for (k, v) in hm2 { + let row = hm1.entry(k).or_insert_with(|| { + changed = true; + t.get_variant(k) + .unwrap() + .iter() + .cloned() + .map(Self::Top) + .collect_vec() + }); + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.join_mut(rhs); + } + } + changed + } + (s, _) => { + *s = Self::Bottom(s.get_type()); + true + } + } + } +} + +// fn input_row<'a>(inp: impl Iterator) -> impl Iterator { +// todo!() +// } + +#[derive(PartialEq, Clone, Eq, Hash, PartialOrd)] +enum ValueRow { + Values(Vec), + Bottom, +} + +impl ValueRow { + fn into_partial_value(self) -> PartialValue { + todo!() + } + + fn new(tr: &TypeRow) -> Self { + Self::Values(tr.iter().cloned().map(PartialValue::Top).collect_vec()) + } + + fn singleton(tr: &TypeRow, idx: usize, v: PartialValue) -> Self { + let mut r = Self::new(tr); + if let Self::Values(vec) = &mut r { + vec[idx] = v; + } + r + } + + fn iter(&self) -> impl Iterator { + std::iter::empty() + } +} + +impl Lattice for ValueRow { + fn meet(self, other: Self) -> Self { + todo!() + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn join_mut(&mut self, other: Self) -> bool { + match (self, other) { + (Self::Bottom, _) => false, + (s, o @ Self::Bottom) => { + *s = o; + true + } + (s, Self::Values(vs2)) => { + let (b, r) = if let Self::Values(vs1) = s { + if vs1.len() == vs2.len() { + let mut changed = false; + for (v1, v2) in zip_eq(vs1.iter_mut(), vs2.into_iter()) { + changed |= v1.join_mut(v2); + } + (false, changed) + } else { + (true, true) + } + } else { + panic!("impossible") + }; + if b { + *s = Self::Bottom; + } + r + } + } + } +} + +fn node_in_value_row<'a>( + ins: impl Iterator, +) -> impl Iterator { + std::iter::empty() +} + +fn tc(hugr: &impl HugrView, node: Node) { + assert!(OpTag::DataflowParent.is_superset(hugr.get_optype(node).tag())); + let d = DescendantsGraph::<'_, Node>::try_new(hugr, node).unwrap(); + ascent_run! { + relation node(Node) = d.nodes().map(|x| (x,)).collect_vec(); + + relation in_wire(Node, IncomingPort); + in_wire(n,p) <-- node(n), for p in d.node_inputs(*n); + + relation out_wire(Node, OutgoingPort); + out_wire(n,p) <-- node(n), for p in d.node_outputs(*n); + + lattice node_in_value_row(Node, ValueRow); + node_in_value_row(n, ValueRow::new(&hugr.signature(*n).unwrap().input)) <-- node(n); + + lattice out_wire_value(Node, OutgoingPort, PartialValue); + out_wire_value(n,p, PartialValue::top_from_hugr(hugr,*n,*p)) <-- out_wire(n,p); + + node_in_value_row(n,ValueRow::singleton(&hugr.signature(*n).unwrap().input, ip.index(), v.clone())) <-- in_wire(n, ip), + if let Some((m,op)) = hugr.single_linked_output(*n, *ip), out_wire_value(m, op, ?v); + + lattice in_wire_value(Node, IncomingPort, PartialValue); + in_wire_value(n,p,v) <-- node_in_value_row(n, ?vr), for (p,v) in vr.iter(); + + relation load_constant_node(Node); + load_constant_node(n) <-- node(n), if hugr.get_optype(*n).is_load_constant(); + out_wire_value(n, 0.into(), PartialValue::from_load_constant(hugr, *n)) <-- load_constant_node(n); + + relation make_tuple_node(Node); + make_tuple_node(n) <-- node(n), if hugr.get_optype(*n).is_make_tuple(); + + out_wire_value(n,0.into(), PartialValue::tuple_from_value_row(vs)) <-- make_tuple_node(n), node_in_value_row(n, ?vs); + }; +} diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 803196144..eab13b21e 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,6 +1,7 @@ //! Compilation passes acting on the HUGR program representation. pub mod const_fold; +pub mod const_fold2; mod half_node; pub mod merge_bbs; pub mod nest_cfgs;