Skip to content

Commit

Permalink
feat!: CustomConst is not restricted to being CustomType (#878)
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q authored Mar 14, 2024
1 parent 6ff6c01 commit d5294ad
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 47 deletions.
8 changes: 4 additions & 4 deletions quantinuum-hugr/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ impl CustomConst for ConstUsize {
ExtensionSet::singleton(&PRELUDE_ID)
}

fn custom_type(&self) -> CustomType {
USIZE_CUSTOM_T
fn get_type(&self) -> Type {
USIZE_T
}
}

Expand Down Expand Up @@ -228,8 +228,8 @@ impl CustomConst for ConstError {
fn extension_reqs(&self) -> ExtensionSet {
ExtensionSet::singleton(&PRELUDE_ID)
}
fn custom_type(&self) -> CustomType {
ERROR_CUSTOM_TYPE
fn get_type(&self) -> Type {
ERROR_TYPE
}
}

Expand Down
7 changes: 3 additions & 4 deletions quantinuum-hugr/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ where
T: CustomConst,
{
fn from(value: T) -> Self {
let typ = Type::new_extension(value.custom_type());
let typ = value.get_type();
Const {
value: Value::custom(value),
typ,
Expand Down Expand Up @@ -263,9 +263,8 @@ mod test {
assert_matches!(classic_t.least_upper_bound(), TypeBound::Eq);
classic_t.check_type(&val).unwrap();

let typ_qb = CustomType::new("mytype", vec![], ex_id, TypeBound::Eq);
let t = Type::new_extension(typ_qb.clone());
assert_matches!(t.check_type(&val),
let typ_qb: Type = CustomType::new("mytype", vec![], ex_id, TypeBound::Eq).into();
assert_matches!(typ_qb.check_type(&val),
Err(ConstTypeError::CustomCheckFail(CustomCheckFailure::TypeMismatch{expected, found})) => expected == typ_int && found == typ_qb);

assert_eq!(val, val);
Expand Down
4 changes: 2 additions & 2 deletions quantinuum-hugr/src/std_extensions/arithmetic/float_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ impl CustomConst for ConstF64 {
format!("f64({})", self.value).into()
}

fn custom_type(&self) -> CustomType {
FLOAT64_CUSTOM_TYPE
fn get_type(&self) -> Type {
FLOAT64_TYPE
}

fn equal_consts(&self, other: &dyn CustomConst) -> bool {
Expand Down
8 changes: 4 additions & 4 deletions quantinuum-hugr/src/std_extensions/arithmetic/int_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ impl CustomConst for ConstIntU {
ExtensionSet::singleton(&EXTENSION_ID)
}

fn custom_type(&self) -> CustomType {
int_custom_type(type_arg(self.log_width))
fn get_type(&self) -> Type {
int_type(type_arg(self.log_width))
}
}

Expand All @@ -175,8 +175,8 @@ impl CustomConst for ConstIntS {
ExtensionSet::singleton(&EXTENSION_ID)
}

fn custom_type(&self) -> CustomType {
int_custom_type(type_arg(self.log_width))
fn get_type(&self) -> Type {
int_type(type_arg(self.log_width))
}
}

Expand Down
31 changes: 18 additions & 13 deletions quantinuum-hugr/src/std_extensions/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ impl ListValue {
pub fn new_empty(typ: Type) -> Self {
Self(vec![], typ)
}

/// Returns the type of the `[ListValue]` as a `[CustomType]`.`
pub fn custom_type(&self) -> CustomType {
list_custom_type(self.1.clone())
}
}

#[typetag::serde]
Expand All @@ -53,11 +58,8 @@ impl CustomConst for ListValue {
SmolStr::new_inline("list")
}

fn custom_type(&self) -> CustomType {
let list_type_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap();
list_type_def
.instantiate(vec![Into::<TypeArg>::into(self.1.clone())])
.unwrap()
fn get_type(&self) -> Type {
self.custom_type().into()
}

fn validate(&self) -> Result<(), CustomCheckFailure> {
Expand Down Expand Up @@ -176,15 +178,18 @@ lazy_static! {
pub static ref EXTENSION: Extension = extension();
}

/// Get the type of a list of `elem_type`
/// Get the type of a list of `elem_type` as a `CustomType`.
pub fn list_custom_type(elem_type: Type) -> CustomType {
EXTENSION
.get_type(&LIST_TYPENAME)
.unwrap()
.instantiate(vec![TypeArg::Type { ty: elem_type }])
.unwrap()
}

/// Get the `Type` of a list of `elem_type`.
pub fn list_type(elem_type: Type) -> Type {
Type::new_extension(
EXTENSION
.get_type(&LIST_TYPENAME)
.unwrap()
.instantiate(vec![TypeArg::Type { ty: elem_type }])
.unwrap(),
)
list_custom_type(elem_type).into()
}

fn list_and_elem_type_vars(list_type_def: &TypeDef) -> (Type, Type) {
Expand Down
6 changes: 3 additions & 3 deletions quantinuum-hugr/src/types/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub enum CustomCheckFailure {
/// The expected custom type.
expected: CustomType,
/// The custom type found when checking.
found: CustomType,
found: Type,
},
/// Any other message
#[error("{0}")]
Expand Down Expand Up @@ -107,8 +107,8 @@ impl Type {
pub fn check_type(&self, val: &Value) -> Result<(), ConstTypeError> {
match (&self.0, val) {
(TypeEnum::Extension(expected), Value::Extension { c: (e_val,) }) => {
let found = e_val.custom_type();
if found == *expected {
let found = e_val.get_type();
if found == expected.clone().into() {
Ok(e_val.validate()?)
} else {
Err(CustomCheckFailure::TypeMismatch {
Expand Down
8 changes: 7 additions & 1 deletion quantinuum-hugr/src/types/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ use std::fmt::{self, Display};

use crate::extension::{ExtensionId, ExtensionRegistry, SignatureError, TypeDef};

use super::TypeName;
use super::{
type_param::{TypeArg, TypeParam},
Substitution, TypeBound,
};
use super::{Type, TypeName};

/// An opaque type element. Contains the unique identifier of its definition.
#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
Expand Down Expand Up @@ -131,3 +131,9 @@ impl Display for CustomType {
}
}
}

impl From<CustomType> for Type {
fn from(value: CustomType) -> Self {
Self::new_extension(value)
}
}
36 changes: 20 additions & 16 deletions quantinuum-hugr/src/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@ use crate::macros::impl_box_clone;

use crate::{Hugr, HugrView};

use crate::types::{CustomCheckFailure, CustomType};
use crate::types::{CustomCheckFailure, Type};

/// A value that can be stored as a static constant. Representing core types and
/// extension types.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "v")]
pub enum Value {
/// An extension constant value, that can check it is of a given [CustomType].
///
/// An extension constant value.
// Note: the extra level of tupling is to avoid https://github.com/rust-lang/rust/issues/78808
Extension {
#[allow(missing_docs)]
Expand Down Expand Up @@ -139,10 +138,12 @@ impl<T: CustomConst> From<T> for Value {
}
}

/// Constant value for opaque [`CustomType`]s.
/// Constant value for opaque `[CustomType]`s.
///
/// When implementing this trait, include the `#[typetag::serde]` attribute to
/// enable serialization.
///
/// [CustomType]: crate::types::CustomType
#[typetag::serde(tag = "c")]
pub trait CustomConst:
Send + Sync + std::fmt::Debug + CustomConstBoxClone + Any + Downcast
Expand Down Expand Up @@ -170,7 +171,7 @@ pub trait CustomConst:
}

/// report the type
fn custom_type(&self) -> CustomType;
fn get_type(&self) -> Type;
}

/// Const equality for types that have PartialEq
Expand All @@ -191,19 +192,22 @@ impl_box_clone!(CustomConst, CustomConstBoxClone);
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
/// A value stored as a serialized blob that can report its own type.
pub struct CustomSerialized {
typ: CustomType,
typ: Type,
value: serde_yaml::Value,
extensions: ExtensionSet,
}

impl CustomSerialized {
/// Creates a new [`CustomSerialized`].
pub fn new(typ: CustomType, value: serde_yaml::Value, exts: impl Into<ExtensionSet>) -> Self {
let extensions = exts.into();
pub fn new(
typ: impl Into<Type>,
value: serde_yaml::Value,
exts: impl Into<ExtensionSet>,
) -> Self {
Self {
typ,
typ: typ.into(),
value,
extensions,
extensions: exts.into(),
}
}

Expand All @@ -226,7 +230,7 @@ impl CustomConst for CustomSerialized {
fn extension_reqs(&self) -> ExtensionSet {
self.extensions.clone()
}
fn custom_type(&self) -> CustomType {
fn get_type(&self) -> Type {
self.typ.clone()
}
}
Expand All @@ -244,9 +248,9 @@ pub(crate) mod test {
use super::*;
use crate::builder::test::simple_dfg_hugr;
use crate::ops::Const;
use crate::std_extensions::arithmetic::float_types::{self, FLOAT64_CUSTOM_TYPE};
use crate::std_extensions::arithmetic::float_types::{self, FLOAT64_TYPE};
use crate::type_row;
use crate::types::{FunctionType, Type};
use crate::types::{CustomType, FunctionType, Type};

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]

Expand All @@ -262,14 +266,14 @@ pub(crate) mod test {
ExtensionSet::singleton(self.0.extension())
}

fn custom_type(&self) -> CustomType {
self.0.clone()
fn get_type(&self) -> Type {
self.0.clone().into()
}
}

pub(crate) fn serialized_float(f: f64) -> Const {
CustomSerialized {
typ: FLOAT64_CUSTOM_TYPE,
typ: FLOAT64_TYPE,
value: serde_yaml::Value::Number(f.into()),
extensions: float_types::EXTENSION_ID.into(),
}
Expand Down

0 comments on commit d5294ad

Please sign in to comment.