Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: CustomConst is not restricted to being CustomType #878

Merged
merged 1 commit into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions quantinuum-hugr/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ impl CustomConst for ConstUsize {
ExtensionSet::singleton(&PRELUDE_ID)
}

fn custom_type(&self) -> CustomType {
USIZE_CUSTOM_T
fn get_type(&self) -> Type {
USIZE_T
}
}

Expand Down Expand Up @@ -228,8 +228,8 @@ impl CustomConst for ConstError {
fn extension_reqs(&self) -> ExtensionSet {
ExtensionSet::singleton(&PRELUDE_ID)
}
fn custom_type(&self) -> CustomType {
ERROR_CUSTOM_TYPE
fn get_type(&self) -> Type {
ERROR_TYPE
}
}

Expand Down
7 changes: 3 additions & 4 deletions quantinuum-hugr/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ where
T: CustomConst,
{
fn from(value: T) -> Self {
let typ = Type::new_extension(value.custom_type());
let typ = value.get_type();
Const {
value: Value::custom(value),
typ,
Expand Down Expand Up @@ -263,9 +263,8 @@ mod test {
assert_matches!(classic_t.least_upper_bound(), TypeBound::Eq);
classic_t.check_type(&val).unwrap();

let typ_qb = CustomType::new("mytype", vec![], ex_id, TypeBound::Eq);
let t = Type::new_extension(typ_qb.clone());
assert_matches!(t.check_type(&val),
let typ_qb: Type = CustomType::new("mytype", vec![], ex_id, TypeBound::Eq).into();
assert_matches!(typ_qb.check_type(&val),
Err(ConstTypeError::CustomCheckFail(CustomCheckFailure::TypeMismatch{expected, found})) => expected == typ_int && found == typ_qb);

assert_eq!(val, val);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ impl CustomConst for ConstF64 {
format!("f64({})", self.value).into()
}

fn custom_type(&self) -> CustomType {
FLOAT64_CUSTOM_TYPE
fn get_type(&self) -> Type {
FLOAT64_TYPE
}

fn equal_consts(&self, other: &dyn CustomConst) -> bool {
Expand Down
8 changes: 4 additions & 4 deletions quantinuum-hugr/src/std_extensions/arithmetic/int_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ impl CustomConst for ConstIntU {
ExtensionSet::singleton(&EXTENSION_ID)
}

fn custom_type(&self) -> CustomType {
int_custom_type(type_arg(self.log_width))
fn get_type(&self) -> Type {
int_type(type_arg(self.log_width))
}
}

Expand All @@ -175,8 +175,8 @@ impl CustomConst for ConstIntS {
ExtensionSet::singleton(&EXTENSION_ID)
}

fn custom_type(&self) -> CustomType {
int_custom_type(type_arg(self.log_width))
fn get_type(&self) -> Type {
int_type(type_arg(self.log_width))
}
}

Expand Down
31 changes: 18 additions & 13 deletions quantinuum-hugr/src/std_extensions/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ impl ListValue {
pub fn new_empty(typ: Type) -> Self {
Self(vec![], typ)
}

/// Returns the type of the `[ListValue]` as a `[CustomType]`.`
pub fn custom_type(&self) -> CustomType {
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
list_custom_type(self.1.clone())
}
}

#[typetag::serde]
Expand All @@ -53,11 +58,8 @@ impl CustomConst for ListValue {
SmolStr::new_inline("list")
}

fn custom_type(&self) -> CustomType {
let list_type_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap();
list_type_def
.instantiate(vec![Into::<TypeArg>::into(self.1.clone())])
.unwrap()
fn get_type(&self) -> Type {
self.custom_type().into()
}

fn validate(&self) -> Result<(), CustomCheckFailure> {
Expand Down Expand Up @@ -176,15 +178,18 @@ lazy_static! {
pub static ref EXTENSION: Extension = extension();
}

/// Get the type of a list of `elem_type`
/// Get the type of a list of `elem_type` as a `CustomType`.
pub fn list_custom_type(elem_type: Type) -> CustomType {
EXTENSION
.get_type(&LIST_TYPENAME)
.unwrap()
.instantiate(vec![TypeArg::Type { ty: elem_type }])
.unwrap()
}

/// Get the `Type` of a list of `elem_type`.
pub fn list_type(elem_type: Type) -> Type {
Type::new_extension(
EXTENSION
.get_type(&LIST_TYPENAME)
.unwrap()
.instantiate(vec![TypeArg::Type { ty: elem_type }])
.unwrap(),
)
list_custom_type(elem_type).into()
}

fn list_and_elem_type_vars(list_type_def: &TypeDef) -> (Type, Type) {
Expand Down
6 changes: 3 additions & 3 deletions quantinuum-hugr/src/types/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub enum CustomCheckFailure {
/// The expected custom type.
expected: CustomType,
/// The custom type found when checking.
found: CustomType,
found: Type,
},
/// Any other message
#[error("{0}")]
Expand Down Expand Up @@ -107,8 +107,8 @@ impl Type {
pub fn check_type(&self, val: &Value) -> Result<(), ConstTypeError> {
match (&self.0, val) {
(TypeEnum::Extension(expected), Value::Extension { c: (e_val,) }) => {
let found = e_val.custom_type();
if found == *expected {
let found = e_val.get_type();
if found == expected.clone().into() {
Ok(e_val.validate()?)
} else {
Err(CustomCheckFailure::TypeMismatch {
Expand Down
8 changes: 7 additions & 1 deletion quantinuum-hugr/src/types/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ use std::fmt::{self, Display};

use crate::extension::{ExtensionId, ExtensionRegistry, SignatureError, TypeDef};

use super::TypeName;
use super::{
type_param::{TypeArg, TypeParam},
Substitution, TypeBound,
};
use super::{Type, TypeName};

/// An opaque type element. Contains the unique identifier of its definition.
#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
Expand Down Expand Up @@ -131,3 +131,9 @@ impl Display for CustomType {
}
}
}

impl From<CustomType> for Type {
fn from(value: CustomType) -> Self {
Self::new_extension(value)
}
}
36 changes: 20 additions & 16 deletions quantinuum-hugr/src/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@ use crate::macros::impl_box_clone;

use crate::{Hugr, HugrView};

use crate::types::{CustomCheckFailure, CustomType};
use crate::types::{CustomCheckFailure, Type};

/// A value that can be stored as a static constant. Representing core types and
/// extension types.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "v")]
pub enum Value {
/// An extension constant value, that can check it is of a given [CustomType].
///
/// An extension constant value.
// Note: the extra level of tupling is to avoid https://github.com/rust-lang/rust/issues/78808
Extension {
#[allow(missing_docs)]
Expand Down Expand Up @@ -139,10 +138,12 @@ impl<T: CustomConst> From<T> for Value {
}
}

/// Constant value for opaque [`CustomType`]s.
/// Constant value for opaque `[CustomType]`s.
///
/// When implementing this trait, include the `#[typetag::serde]` attribute to
/// enable serialization.
///
/// [CustomType]: crate::types::CustomType
#[typetag::serde(tag = "c")]
pub trait CustomConst:
Send + Sync + std::fmt::Debug + CustomConstBoxClone + Any + Downcast
Expand Down Expand Up @@ -170,7 +171,7 @@ pub trait CustomConst:
}

/// report the type
fn custom_type(&self) -> CustomType;
fn get_type(&self) -> Type;
}

/// Const equality for types that have PartialEq
Expand All @@ -191,19 +192,22 @@ impl_box_clone!(CustomConst, CustomConstBoxClone);
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
/// A value stored as a serialized blob that can report its own type.
pub struct CustomSerialized {
typ: CustomType,
typ: Type,
value: serde_yaml::Value,
extensions: ExtensionSet,
}

impl CustomSerialized {
/// Creates a new [`CustomSerialized`].
pub fn new(typ: CustomType, value: serde_yaml::Value, exts: impl Into<ExtensionSet>) -> Self {
let extensions = exts.into();
pub fn new(
typ: impl Into<Type>,
value: serde_yaml::Value,
exts: impl Into<ExtensionSet>,
) -> Self {
Self {
typ,
typ: typ.into(),
value,
extensions,
extensions: exts.into(),
}
}

Expand All @@ -226,7 +230,7 @@ impl CustomConst for CustomSerialized {
fn extension_reqs(&self) -> ExtensionSet {
self.extensions.clone()
}
fn custom_type(&self) -> CustomType {
fn get_type(&self) -> Type {
self.typ.clone()
}
}
Expand All @@ -244,9 +248,9 @@ pub(crate) mod test {
use super::*;
use crate::builder::test::simple_dfg_hugr;
use crate::ops::Const;
use crate::std_extensions::arithmetic::float_types::{self, FLOAT64_CUSTOM_TYPE};
use crate::std_extensions::arithmetic::float_types::{self, FLOAT64_TYPE};
use crate::type_row;
use crate::types::{FunctionType, Type};
use crate::types::{CustomType, FunctionType, Type};

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]

Expand All @@ -262,14 +266,14 @@ pub(crate) mod test {
ExtensionSet::singleton(self.0.extension())
}

fn custom_type(&self) -> CustomType {
self.0.clone()
fn get_type(&self) -> Type {
self.0.clone().into()
}
}

pub(crate) fn serialized_float(f: f64) -> Const {
CustomSerialized {
typ: FLOAT64_CUSTOM_TYPE,
typ: FLOAT64_TYPE,
value: serde_yaml::Value::Number(f.into()),
extensions: float_types::EXTENSION_ID.into(),
}
Expand Down