From f77f5ed531052c3b454f8eb24ea0ce0ca611fe71 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Wed, 8 May 2024 08:22:41 +0100 Subject: [PATCH] Rename {Const,Value}::const_type -> get_type. Add some useful functions for `ops::constant::ExtensionValue` --- hugr/src/algorithm/const_fold.rs | 4 +- hugr/src/builder/build_traits.rs | 2 +- hugr/src/builder/tail_loop.rs | 2 +- hugr/src/ops/constant.rs | 120 +++++++++++++++---------- hugr/src/ops/constant/custom.rs | 16 ++-- hugr/src/std_extensions/collections.rs | 2 +- hugr/src/std_extensions/logic.rs | 2 +- hugr/src/types/check.rs | 4 +- 8 files changed, 87 insertions(+), 65 deletions(-) diff --git a/hugr/src/algorithm/const_fold.rs b/hugr/src/algorithm/const_fold.rs index 7aef60c88..047665234 100644 --- a/hugr/src/algorithm/const_fold.rs +++ b/hugr/src/algorithm/const_fold.rs @@ -75,7 +75,7 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR /// Generate a graph that loads and outputs `consts` in order, validating /// against `reg`. fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { - let const_types = consts.iter().map(Value::const_type).collect_vec(); + let const_types = consts.iter().map(Value::get_type).collect_vec(); let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap(); let outputs = consts @@ -337,7 +337,7 @@ mod test { let list: Value = ListValue::new(BOOL_T, [Value::unit_sum(0, 1).unwrap()]).into(); let mut build = DFGBuilder::new(FunctionType::new( type_row![], - vec![list.const_type().clone()], + vec![list.get_type().clone()], )) .unwrap(); diff --git a/hugr/src/builder/build_traits.rs b/hugr/src/builder/build_traits.rs index 3554a87c5..e3ef10de8 100644 --- a/hugr/src/builder/build_traits.rs +++ b/hugr/src/builder/build_traits.rs @@ -358,7 +358,7 @@ pub trait Dataflow: Container { let load_n = self .add_dataflow_op( ops::LoadConstant { - datatype: op.const_type().clone(), + datatype: op.get_type().clone(), }, // Constant wire from the constant value node vec![Wire::new(const_node, OutgoingPort::from(0))], diff --git a/hugr/src/builder/tail_loop.rs b/hugr/src/builder/tail_loop.rs index 40baf8657..901324fe9 100644 --- a/hugr/src/builder/tail_loop.rs +++ b/hugr/src/builder/tail_loop.rs @@ -148,7 +148,7 @@ mod test { let const_wire = loop_b.add_load_const(Value::true_val()); let lift_node = loop_b.add_dataflow_op( ops::Lift { - type_row: vec![const_val.const_type().clone()].into(), + type_row: vec![const_val.get_type().clone()].into(), new_extension: PRELUDE_ID, }, [const_wire], diff --git a/hugr/src/ops/constant.rs b/hugr/src/ops/constant.rs index bb13a7317..01f9e46ea 100644 --- a/hugr/src/ops/constant.rs +++ b/hugr/src/ops/constant.rs @@ -8,6 +8,7 @@ use crate::extension::ExtensionSet; use crate::types::{CustomType, EdgeKind, FunctionType, SumType, SumTypeError, Type}; use crate::{Hugr, HugrView}; +use delegate::delegate; use itertools::Itertools; use smol_str::SmolStr; use thiserror::Error; @@ -35,9 +36,11 @@ impl Const { &self.value } - /// Returns a reference to the type of this constant. - pub fn const_type(&self) -> Type { - self.value.const_type() + delegate! { + to self.value { + /// Returns the type of this constant. + pub fn get_type(&self) -> Type; + } } } @@ -47,6 +50,34 @@ impl From for Const { } } +impl NamedOp for Const { + fn name(&self) -> OpName { + self.value().name() + } +} + +impl StaticTag for Const { + const TAG: OpTag = OpTag::Const; +} + +impl OpTrait for Const { + fn description(&self) -> &str { + "Constant value" + } + + fn extension_delta(&self) -> ExtensionSet { + self.value().extension_reqs() + } + + fn tag(&self) -> OpTag { + ::TAG + } + + fn static_output(&self) -> Option { + Some(EdgeKind::Const(self.get_type())) + } +} + impl From for Value { fn from(konst: Const) -> Self { konst.value @@ -96,16 +127,36 @@ pub enum Value { }, } -/// Boxed [`CustomConst`] trait object. +/// An opaque newtype awround a `Box`. /// -/// Use [`Value::extension`] to create a new variant of this type. -/// -/// This is required to avoid in -/// [`Value::Extension`], while implementing a transparent encoding into a -/// `CustomConst`. +/// This will be the serialisation barrier that ensures all implementors of +/// [`CustomConst`] are serialised through [`CustomSerialized`]. #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] #[serde(transparent)] -pub struct ExtensionValue(pub(super) Box); +pub struct ExtensionValue(Box); + +impl ExtensionValue { + /// Create a new [`ExtensionValue`] from any [`CustomConst`]. + pub fn new(cc: impl CustomConst) -> Self { + Self(Box::new(cc)) + } + + /// Returns a reference to the internal [`CustomConst`]. + pub fn value(&self) -> &dyn CustomConst { + self.0.as_ref() + } + + delegate! { + to self.0 { + /// Returns the type of the internal [`CustomConst`]. + pub fn get_type(&self) -> Type; + /// An identifier of the internal [`CustomConst`]. + pub fn name(&self) -> ValueName; + /// The extension(s) defining the internal [`CustomConst`]. + pub fn extension_reqs(&self) -> ExtensionSet; + } + } +} impl PartialEq for ExtensionValue { fn eq(&self, other: &Self) -> bool { @@ -166,11 +217,11 @@ fn mono_fn_type(h: &Hugr) -> Result { } impl Value { - /// Returns a reference to the type of this [`Value`]. - pub fn const_type(&self) -> Type { + /// Returns the type of this [`Value`]. + pub fn get_type(&self) -> Type { match self { - Self::Extension { e } => e.0.get_type(), - Self::Tuple { vs } => Type::new_tuple(vs.iter().map(Self::const_type).collect_vec()), + Self::Extension { e } => e.get_type(), + Self::Tuple { vs } => Type::new_tuple(vs.iter().map(Self::get_type).collect_vec()), Self::Sum { sum_type, .. } => sum_type.clone().into(), Self::Function { hugr } => { let func_type = mono_fn_type(hugr).unwrap_or_else(|e| panic!("{}", e)); @@ -268,7 +319,7 @@ impl Value { fn name(&self) -> OpName { match self { - Self::Extension { e } => format!("const:custom:{}", e.0.name()), + Self::Extension { e } => format!("const:custom:{}", e.name()), Self::Function { hugr: h } => { let Some(t) = h.get_function_type() else { panic!("HUGR root node isn't a valid function parent."); @@ -289,7 +340,7 @@ impl Value { /// The extensions required by a [`Value`] pub fn extension_reqs(&self) -> ExtensionSet { match self { - Self::Extension { e } => e.0.extension_reqs().clone(), + Self::Extension { e } => e.extension_reqs().clone(), Self::Function { .. } => ExtensionSet::new(), // no extensions required to load Hugr (only to run) Self::Tuple { vs } => ExtensionSet::union_over(vs.iter().map(Value::extension_reqs)), Self::Sum { values, .. } => { @@ -299,35 +350,6 @@ impl Value { } } -impl NamedOp for Const { - fn name(&self) -> OpName { - self.value().name() - } -} - -impl StaticTag for Const { - const TAG: OpTag = OpTag::Const; -} -impl OpTrait for Const { - fn description(&self) -> &str { - "Constant value" - } - - fn extension_delta(&self) -> ExtensionSet { - self.value().extension_reqs() - } - - fn tag(&self) -> OpTag { - ::TAG - } - - fn static_output(&self) -> Option { - Some(EdgeKind::Const(self.const_type())) - } -} - -// [KnownTypeConst] is guaranteed to be the right type, so can be constructed -// without initial type check. impl From for Value where T: CustomConst, @@ -484,7 +506,7 @@ mod test { crate::extension::prelude::BOOL_T ])); - assert_eq!(v.const_type(), correct_type); + assert_eq!(v.get_type(), correct_type); assert!(v.name().starts_with("const:function:")) } @@ -508,7 +530,7 @@ mod test { #[case] expected_type: Type, #[case] name_prefix: &str, ) { - assert_eq!(const_value.const_type(), expected_type); + assert_eq!(const_value.get_type(), expected_type); let name = const_value.name(); assert!( name.starts_with(name_prefix), @@ -541,10 +563,10 @@ mod test { .into(); let classic_t = Type::new_extension(typ_int.clone()); assert_matches!(classic_t.least_upper_bound(), TypeBound::Eq); - assert_eq!(yaml_const.const_type(), classic_t); + assert_eq!(yaml_const.get_type(), classic_t); let typ_qb = CustomType::new("my_type", vec![], ex_id, TypeBound::Eq); let t = Type::new_extension(typ_qb.clone()); - assert_ne!(yaml_const.const_type(), t); + assert_ne!(yaml_const.get_type(), t); } } diff --git a/hugr/src/ops/constant/custom.rs b/hugr/src/ops/constant/custom.rs index e69b62f36..f5937028a 100644 --- a/hugr/src/ops/constant/custom.rs +++ b/hugr/src/ops/constant/custom.rs @@ -35,7 +35,7 @@ pub trait CustomConst: /// [USize]: crate::extension::prelude::USIZE_T fn extension_reqs(&self) -> ExtensionSet; - /// Check the value is a valid instance of the provided type. + /// Check the value. fn validate(&self) -> Result<(), CustomCheckFailure> { Ok(()) } @@ -48,10 +48,16 @@ pub trait CustomConst: false } - /// report the type + /// Report the type. fn get_type(&self) -> Type; } +impl PartialEq for dyn CustomConst { + fn eq(&self, other: &Self) -> bool { + (*self).equal_consts(other) + } +} + /// Const equality for types that have PartialEq pub fn downcast_equal_consts( constant: &T, @@ -112,9 +118,3 @@ impl CustomConst for CustomSerialized { self.typ.clone() } } - -impl PartialEq for dyn CustomConst { - fn eq(&self, other: &Self) -> bool { - (*self).equal_consts(other) - } -} diff --git a/hugr/src/std_extensions/collections.rs b/hugr/src/std_extensions/collections.rs index 4896f1ab7..747841a6b 100644 --- a/hugr/src/std_extensions/collections.rs +++ b/hugr/src/std_extensions/collections.rs @@ -84,7 +84,7 @@ impl CustomConst for ListValue { // check all values are instances of the element type for v in &self.0 { - if v.const_type() != *ty { + if v.get_type() != *ty { return Err(error()); } } diff --git a/hugr/src/std_extensions/logic.rs b/hugr/src/std_extensions/logic.rs index 6978b2d0c..86de41b28 100644 --- a/hugr/src/std_extensions/logic.rs +++ b/hugr/src/std_extensions/logic.rs @@ -255,7 +255,7 @@ pub(crate) mod test { let true_val = r.get_value(&TRUE_NAME).unwrap(); for v in [false_val, true_val] { - let simpl = v.typed_value().const_type(); + let simpl = v.typed_value().get_type(); assert_eq!(simpl, BOOL_T); } } diff --git a/hugr/src/types/check.rs b/hugr/src/types/check.rs index 3de851521..174ff817d 100644 --- a/hugr/src/types/check.rs +++ b/hugr/src/types/check.rs @@ -10,7 +10,7 @@ use crate::ops::Value; #[non_exhaustive] pub enum SumTypeError { /// The type of the variant doesn't match the type of the value. - #[error("Expected type {expected} for element {index} of variant #{tag}, but found {}", .found.const_type())] + #[error("Expected type {expected} for element {index} of variant #{tag}, but found {}", .found.get_type())] InvalidValueType { /// Tag of the variant. tag: usize, @@ -70,7 +70,7 @@ impl super::SumType { } for (index, (t, v)) in itertools::zip_eq(variant.iter(), val.iter()).enumerate() { - if v.const_type() != *t { + if v.get_type() != *t { Err(SumTypeError::InvalidValueType { tag, index,