diff --git a/quantinuum-hugr/src/extension/prelude.rs b/quantinuum-hugr/src/extension/prelude.rs index 1cdc8b731..d6bf068d3 100644 --- a/quantinuum-hugr/src/extension/prelude.rs +++ b/quantinuum-hugr/src/extension/prelude.rs @@ -191,8 +191,8 @@ impl CustomConst for ConstUsize { ExtensionSet::singleton(&PRELUDE_ID) } - fn custom_type(&self) -> CustomType { - USIZE_CUSTOM_T + fn get_type(&self) -> Type { + USIZE_T } } @@ -228,8 +228,8 @@ impl CustomConst for ConstError { fn extension_reqs(&self) -> ExtensionSet { ExtensionSet::singleton(&PRELUDE_ID) } - fn custom_type(&self) -> CustomType { - ERROR_CUSTOM_TYPE + fn get_type(&self) -> Type { + ERROR_TYPE } } diff --git a/quantinuum-hugr/src/ops/constant.rs b/quantinuum-hugr/src/ops/constant.rs index f7fe8562e..7c9d94488 100644 --- a/quantinuum-hugr/src/ops/constant.rs +++ b/quantinuum-hugr/src/ops/constant.rs @@ -129,7 +129,7 @@ where T: CustomConst, { fn from(value: T) -> Self { - let typ = Type::new_extension(value.custom_type()); + let typ = value.get_type(); Const { value: Value::custom(value), typ, @@ -263,9 +263,8 @@ mod test { assert_matches!(classic_t.least_upper_bound(), TypeBound::Eq); classic_t.check_type(&val).unwrap(); - let typ_qb = CustomType::new("mytype", vec![], ex_id, TypeBound::Eq); - let t = Type::new_extension(typ_qb.clone()); - assert_matches!(t.check_type(&val), + let typ_qb: Type = CustomType::new("mytype", vec![], ex_id, TypeBound::Eq).into(); + assert_matches!(typ_qb.check_type(&val), Err(ConstTypeError::CustomCheckFail(CustomCheckFailure::TypeMismatch{expected, found})) => expected == typ_int && found == typ_qb); assert_eq!(val, val); diff --git a/quantinuum-hugr/src/std_extensions/arithmetic/float_types.rs b/quantinuum-hugr/src/std_extensions/arithmetic/float_types.rs index d78ed4ab2..3bc15dde9 100644 --- a/quantinuum-hugr/src/std_extensions/arithmetic/float_types.rs +++ b/quantinuum-hugr/src/std_extensions/arithmetic/float_types.rs @@ -56,8 +56,8 @@ impl CustomConst for ConstF64 { format!("f64({})", self.value).into() } - fn custom_type(&self) -> CustomType { - FLOAT64_CUSTOM_TYPE + fn get_type(&self) -> Type { + FLOAT64_TYPE } fn equal_consts(&self, other: &dyn CustomConst) -> bool { diff --git a/quantinuum-hugr/src/std_extensions/arithmetic/int_types.rs b/quantinuum-hugr/src/std_extensions/arithmetic/int_types.rs index 9af7e9cb8..d7501f5a3 100644 --- a/quantinuum-hugr/src/std_extensions/arithmetic/int_types.rs +++ b/quantinuum-hugr/src/std_extensions/arithmetic/int_types.rs @@ -157,8 +157,8 @@ impl CustomConst for ConstIntU { ExtensionSet::singleton(&EXTENSION_ID) } - fn custom_type(&self) -> CustomType { - int_custom_type(type_arg(self.log_width)) + fn get_type(&self) -> Type { + int_type(type_arg(self.log_width)) } } @@ -175,8 +175,8 @@ impl CustomConst for ConstIntS { ExtensionSet::singleton(&EXTENSION_ID) } - fn custom_type(&self) -> CustomType { - int_custom_type(type_arg(self.log_width)) + fn get_type(&self) -> Type { + int_type(type_arg(self.log_width)) } } diff --git a/quantinuum-hugr/src/std_extensions/collections.rs b/quantinuum-hugr/src/std_extensions/collections.rs index 3cd89c2bb..2c078a463 100644 --- a/quantinuum-hugr/src/std_extensions/collections.rs +++ b/quantinuum-hugr/src/std_extensions/collections.rs @@ -45,6 +45,11 @@ impl ListValue { pub fn new_empty(typ: Type) -> Self { Self(vec![], typ) } + + /// Returns the type of the `[ListValue]` as a `[CustomType]`.` + pub fn custom_type(&self) -> CustomType { + list_custom_type(self.1.clone()) + } } #[typetag::serde] @@ -53,11 +58,8 @@ impl CustomConst for ListValue { SmolStr::new_inline("list") } - fn custom_type(&self) -> CustomType { - let list_type_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap(); - list_type_def - .instantiate(vec![Into::::into(self.1.clone())]) - .unwrap() + fn get_type(&self) -> Type { + self.custom_type().into() } fn validate(&self) -> Result<(), CustomCheckFailure> { @@ -176,15 +178,18 @@ lazy_static! { pub static ref EXTENSION: Extension = extension(); } -/// Get the type of a list of `elem_type` +/// Get the type of a list of `elem_type` as a `CustomType`. +pub fn list_custom_type(elem_type: Type) -> CustomType { + EXTENSION + .get_type(&LIST_TYPENAME) + .unwrap() + .instantiate(vec![TypeArg::Type { ty: elem_type }]) + .unwrap() +} + +/// Get the `Type` of a list of `elem_type`. pub fn list_type(elem_type: Type) -> Type { - Type::new_extension( - EXTENSION - .get_type(&LIST_TYPENAME) - .unwrap() - .instantiate(vec![TypeArg::Type { ty: elem_type }]) - .unwrap(), - ) + list_custom_type(elem_type).into() } fn list_and_elem_type_vars(list_type_def: &TypeDef) -> (Type, Type) { diff --git a/quantinuum-hugr/src/types/check.rs b/quantinuum-hugr/src/types/check.rs index 987a872ba..2a97f61d7 100644 --- a/quantinuum-hugr/src/types/check.rs +++ b/quantinuum-hugr/src/types/check.rs @@ -19,7 +19,7 @@ pub enum CustomCheckFailure { /// The expected custom type. expected: CustomType, /// The custom type found when checking. - found: CustomType, + found: Type, }, /// Any other message #[error("{0}")] @@ -107,8 +107,8 @@ impl Type { pub fn check_type(&self, val: &Value) -> Result<(), ConstTypeError> { match (&self.0, val) { (TypeEnum::Extension(expected), Value::Extension { c: (e_val,) }) => { - let found = e_val.custom_type(); - if found == *expected { + let found = e_val.get_type(); + if found == expected.clone().into() { Ok(e_val.validate()?) } else { Err(CustomCheckFailure::TypeMismatch { diff --git a/quantinuum-hugr/src/types/custom.rs b/quantinuum-hugr/src/types/custom.rs index 50e112502..ba738227a 100644 --- a/quantinuum-hugr/src/types/custom.rs +++ b/quantinuum-hugr/src/types/custom.rs @@ -5,11 +5,11 @@ use std::fmt::{self, Display}; use crate::extension::{ExtensionId, ExtensionRegistry, SignatureError, TypeDef}; -use super::TypeName; use super::{ type_param::{TypeArg, TypeParam}, Substitution, TypeBound, }; +use super::{Type, TypeName}; /// An opaque type element. Contains the unique identifier of its definition. #[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)] @@ -131,3 +131,9 @@ impl Display for CustomType { } } } + +impl From for Type { + fn from(value: CustomType) -> Self { + Self::new_extension(value) + } +} diff --git a/quantinuum-hugr/src/values.rs b/quantinuum-hugr/src/values.rs index cdc7058bc..8a8f2e55f 100644 --- a/quantinuum-hugr/src/values.rs +++ b/quantinuum-hugr/src/values.rs @@ -14,15 +14,14 @@ use crate::macros::impl_box_clone; use crate::{Hugr, HugrView}; -use crate::types::{CustomCheckFailure, CustomType}; +use crate::types::{CustomCheckFailure, Type}; /// A value that can be stored as a static constant. Representing core types and /// extension types. #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] #[serde(tag = "v")] pub enum Value { - /// An extension constant value, that can check it is of a given [CustomType]. - /// + /// An extension constant value. // Note: the extra level of tupling is to avoid https://github.com/rust-lang/rust/issues/78808 Extension { #[allow(missing_docs)] @@ -139,10 +138,12 @@ impl From for Value { } } -/// Constant value for opaque [`CustomType`]s. +/// Constant value for opaque `[CustomType]`s. /// /// When implementing this trait, include the `#[typetag::serde]` attribute to /// enable serialization. +/// +/// [CustomType]: crate::types::CustomType #[typetag::serde(tag = "c")] pub trait CustomConst: Send + Sync + std::fmt::Debug + CustomConstBoxClone + Any + Downcast @@ -170,7 +171,7 @@ pub trait CustomConst: } /// report the type - fn custom_type(&self) -> CustomType; + fn get_type(&self) -> Type; } /// Const equality for types that have PartialEq @@ -191,19 +192,22 @@ impl_box_clone!(CustomConst, CustomConstBoxClone); #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] /// A value stored as a serialized blob that can report its own type. pub struct CustomSerialized { - typ: CustomType, + typ: Type, value: serde_yaml::Value, extensions: ExtensionSet, } impl CustomSerialized { /// Creates a new [`CustomSerialized`]. - pub fn new(typ: CustomType, value: serde_yaml::Value, exts: impl Into) -> Self { - let extensions = exts.into(); + pub fn new( + typ: impl Into, + value: serde_yaml::Value, + exts: impl Into, + ) -> Self { Self { - typ, + typ: typ.into(), value, - extensions, + extensions: exts.into(), } } @@ -226,7 +230,7 @@ impl CustomConst for CustomSerialized { fn extension_reqs(&self) -> ExtensionSet { self.extensions.clone() } - fn custom_type(&self) -> CustomType { + fn get_type(&self) -> Type { self.typ.clone() } } @@ -244,9 +248,9 @@ pub(crate) mod test { use super::*; use crate::builder::test::simple_dfg_hugr; use crate::ops::Const; - use crate::std_extensions::arithmetic::float_types::{self, FLOAT64_CUSTOM_TYPE}; + use crate::std_extensions::arithmetic::float_types::{self, FLOAT64_TYPE}; use crate::type_row; - use crate::types::{FunctionType, Type}; + use crate::types::{CustomType, FunctionType, Type}; #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] @@ -262,14 +266,14 @@ pub(crate) mod test { ExtensionSet::singleton(self.0.extension()) } - fn custom_type(&self) -> CustomType { - self.0.clone() + fn get_type(&self) -> Type { + self.0.clone().into() } } pub(crate) fn serialized_float(f: f64) -> Const { CustomSerialized { - typ: FLOAT64_CUSTOM_TYPE, + typ: FLOAT64_TYPE, value: serde_yaml::Value::Number(f.into()), extensions: float_types::EXTENSION_ID.into(), }