Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed May 7, 2024
1 parent 6a7f2b0 commit a4bfc32
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 80 deletions.
30 changes: 8 additions & 22 deletions hugr/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,16 +185,9 @@ impl ExtensionValue {
}
}

impl From<CustomSerialized> for ExtensionValue {
fn from(x: CustomSerialized) -> Self {
Self { v: x.into() }
}
}

impl TryFrom<ExtensionValue> for CustomSerialized {
type Error = CustomSerializedError;
fn try_from(x: ExtensionValue) -> Result<Self, Self::Error> {
x.v.as_ref().try_into()
impl<CC: CustomConst> From<CC> for ExtensionValue {
fn from(x: CC) -> Self {
Self::new(x)
}
}

Expand Down Expand Up @@ -390,8 +383,6 @@ impl Value {
}
}

// [KnownTypeConst] is guaranteed to be the right type, so can be constructed
// without initial type check.
impl<T> From<T> for Value
where
T: CustomConst,
Expand Down Expand Up @@ -453,12 +444,7 @@ mod test {

/// A [`CustomSerialized`] encoding a [`FLOAT64_TYPE`] float constant used in testing.
pub(crate) fn serialized_float(f: f64) -> Value {
CustomSerialized::new(
FLOAT64_TYPE,
serde_yaml::Value::Number(f.into()),
float_types::EXTENSION_ID,
)
.into()
CustomSerialized::try_from_custom_const(ConstF64::new(f)).unwrap().into()
}

fn test_registry() -> ExtensionRegistry {
Expand All @@ -480,7 +466,7 @@ mod test {
0,
[
CustomTestValue(USIZE_CUSTOM_T).into(),
serialized_float(5.1),
ConstF64::new(5.1).into(),
],
pred_ty.clone(),
)?);
Expand Down Expand Up @@ -565,7 +551,7 @@ mod test {
#[rstest]
#[case(Value::unit(), Type::UNIT, "const:seq:{}")]
#[case(const_usize(), USIZE_T, "const:custom:ConstUsize(")]
#[case(serialized_float(17.4), FLOAT64_TYPE, "const:custom:yaml:Number(17.4)")]
// #[case(serialized_float(17.4), FLOAT64_TYPE, "const:custom:yaml:Number(17.4)")]
#[case(const_tuple(), Type::new_tuple(type_row![USIZE_T, FLOAT64_TYPE]), "const:seq:{")]
fn const_type(
#[case] const_value: Value,
Expand Down Expand Up @@ -600,8 +586,8 @@ mod test {
ex_id.clone(),
TypeBound::Eq,
);
let yaml_const: Value =
CustomSerialized::new(typ_int.clone(), YamlValue::Number(6.into()), ex_id.clone())
let yaml_const: Value = CustomSerialized::new
(typ_int.clone(), YamlValue::Number(6.into()), ex_id.clone())
.into();
let classic_t = Type::new_extension(typ_int.clone());
assert_matches!(classic_t.least_upper_bound(), TypeBound::Eq);
Expand Down
165 changes: 107 additions & 58 deletions hugr/src/ops/constant/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,24 @@ impl PartialEq for dyn CustomConst {
}
}

impl dyn CustomConst {
fn serialize_dyn_custom_const(&self) -> Result<serde_yaml::Value, serde_yaml::Error> {
serde_yaml::to_value(self)
/// Const equality for types that have PartialEq
pub fn downcast_equal_consts<T: CustomConst + PartialEq>(
constant: &T,
other: &dyn CustomConst,
) -> bool {
if let Some(other) = other.as_any().downcast_ref::<T>() {
constant == other
} else {
false
}
}

fn serialize_custom_const(cc: &impl CustomConst) -> Result<serde_yaml::Value, serde_yaml::Error> {
(cc as &dyn CustomConst).serialize_dyn_custom_const()
}

fn deserialize_dyn_custom_const(
value: serde_yaml::Value,
) -> Result<Box<dyn CustomConst>, serde_yaml::Error> {
serde_yaml::from_value(value)
/// Serialize any CustomConst using the `impl Serialize for &dyn CustomConst`.
/// In particular this works on `&dyn CustomConst` and `Box<dyn CustomConst>::Target`.
/// See tests below
fn serialize_custom_const(cc: &dyn CustomConst) -> Result<serde_yaml::Value, serde_yaml::Error> {
serde_yaml::to_value(cc)
}

fn deserialize_custom_const<CC: CustomConst>(
Expand All @@ -81,7 +85,6 @@ fn deserialize_custom_const<CC: CustomConst>(
match deserialize_dyn_custom_const(value)?.downcast::<CC>() {
Ok(cc) => Ok(*cc),
Err(dyn_cc) => {
// TODO we should return dyn_cc in the error
Err(<serde_yaml::Error as serde::de::Error>::custom(format!(
"Failed to deserialize [{}]: {:?}",
std::any::type_name::<CC>(),
Expand All @@ -91,18 +94,14 @@ fn deserialize_custom_const<CC: CustomConst>(
}
}

/// Const equality for types that have PartialEq
pub fn downcast_equal_consts<T: CustomConst + PartialEq>(
constant: &T,
other: &dyn CustomConst,
) -> bool {
if let Some(other) = other.as_any().downcast_ref::<T>() {
constant == other
} else {
false
}
fn deserialize_dyn_custom_const(
value: serde_yaml::Value,
) -> Result<Box<dyn CustomConst>, serde_yaml::Error> {
serde_yaml::from_value(value)
}



impl_downcast!(CustomConst);
impl_box_clone!(CustomConst, CustomConstBoxClone);

Expand Down Expand Up @@ -160,49 +159,45 @@ impl CustomSerialized {
}

/// TODO
pub fn from_custom_const_ref(
pub fn try_from_custom_const_ref(
cc: &(impl CustomConst + ?Sized),
) -> Result<Self, CustomSerializedError> {
Self::from_custom_const_box(cc.clone_box())
Self::try_from_custom_const_box(cc.clone_box())
}

/// TODO
pub fn from_custom_const(cc: impl CustomConst) -> Result<Self, CustomSerializedError> {
Self::from_custom_const_box(Box::new(cc))
pub fn try_from_custom_const(cc: impl CustomConst) -> Result<Self, CustomSerializedError> {
Self::try_from_custom_const_box(Box::new(cc))
}

/// TODO
pub fn from_custom_const_box(cc: Box<dyn CustomConst>) -> Result<Self, CustomSerializedError> {
pub fn try_from_custom_const_box(cc: Box<dyn CustomConst>) -> Result<Self, CustomSerializedError> {
match cc.downcast::<Self>() {
Ok(x) => Ok(*x),
Err(cc) => {
let (typ, extension_reqs) = (cc.get_type(), cc.extension_reqs());
let value = cc
.serialize_dyn_custom_const()
let value =
serialize_custom_const(cc.as_ref())
.map_err(|err| CustomSerializedError::new_ser(err, cc))?;
Ok(Self::new(typ, value, extension_reqs))
}
}
}

/// TODO
pub fn into_custom_const_box(self) -> Result<Box<dyn CustomConst>, CustomSerializedError> {
let CustomSerialized {
typ,
value,
extensions,
} = self;
pub fn into_custom_const_box(self) -> Box<dyn CustomConst> {
let (typ, extensions) = (self.get_type().clone(), self.extension_reqs());
// ideally we would not have to clone, but serde_json does not allow us
// to recover the value from the error
let cc_box = deserialize_dyn_custom_const(value.clone())
.map_err(|err| CustomSerializedError::new_de(err, value))?;
let cc_box = deserialize_dyn_custom_const(self.value.clone())
.unwrap_or_else(|_| Box::new(self));
assert_eq!(cc_box.get_type(), typ);
assert_eq!(cc_box.extension_reqs(), extensions);
Ok(cc_box)
cc_box
}

/// TODO
pub fn into_custom_const<CC: CustomConst>(self) -> Result<CC, CustomSerializedError> {
pub fn try_into_custom_const<CC: CustomConst>(self) -> Result<CC, CustomSerializedError> {
let CustomSerialized {
typ,
value,
Expand Down Expand Up @@ -239,70 +234,124 @@ impl CustomConst for CustomSerialized {
impl TryFrom<&dyn CustomConst> for CustomSerialized {
type Error = CustomSerializedError;
fn try_from(value: &dyn CustomConst) -> Result<Self, Self::Error> {
Self::from_custom_const_box(value.clone_box())
Self::try_from_custom_const_ref(value)
}
}

impl TryFrom<Box<dyn CustomConst>> for CustomSerialized {
type Error = CustomSerializedError;
fn try_from(value: Box<dyn CustomConst>) -> Result<Self, Self::Error> {
Self::from_custom_const_box(value)
Self::try_from_custom_const_box(value)
}
}

impl From<CustomSerialized> for Box<dyn CustomConst> {
fn from(value: CustomSerialized) -> Self {
let (typ, extension_reqs) = (value.get_type(), value.extension_reqs());
value.into_custom_const_box().unwrap_or_else(|err| match err {
CustomSerializedError::DeserializePayloadError { payload, .. } =>
Box::new(CustomSerialized::new(typ, payload, extension_reqs)),
_ => panic!("CustomSerialized::into_custom_const_box returned an error other than DeserializePayloadError")
})
fn from(cs: CustomSerialized) -> Self {
cs.into_custom_const_box()
}
}

#[cfg(test)]
mod test {


use downcast_rs::Downcast;
use rstest::rstest;

use crate::{
ops::constant::custom::{
extension::{prelude::{ConstUsize, USIZE_T}, ExtensionSet}, ops::{constant::custom::{
deserialize_custom_const, deserialize_dyn_custom_const, serialize_custom_const,
},
std_extensions::arithmetic::int_types::ConstInt,
}, Value}, std_extensions::{arithmetic::int_types::ConstInt, collections::ListValue}
};

use super::{CustomConst, CustomConstBoxClone, CustomSerialized};

struct SerializeCustomConstExample<CC: CustomConst + serde::Serialize + 'static> {
cc: CC,
tag: &'static str,
yaml: serde_yaml::Value
}

impl<CC : CustomConst + serde::Serialize + 'static> SerializeCustomConstExample<CC> {
fn new(cc: CC, tag: &'static str) -> Self {
let yaml = serde_yaml::to_value(&cc).unwrap();
Self { cc, tag, yaml }
}
}

fn ser_cc_ex1() -> SerializeCustomConstExample<ConstUsize> {
SerializeCustomConstExample::new(ConstUsize::new(12), "ConstUsize")
}

fn ser_cc_ex2() -> SerializeCustomConstExample<ListValue> {
SerializeCustomConstExample::new(ListValue::new(USIZE_T, [ConstUsize::new(1), ConstUsize::new(2)].into_iter().map(Value::extension)), "ListValue")
}

fn ser_cc_ex3() -> SerializeCustomConstExample<CustomSerialized> {
SerializeCustomConstExample::new(CustomSerialized::new(USIZE_T, serde_yaml::Value::Null, ExtensionSet::default()), "CustomSerialized")
}

#[rstest]
#[case(ser_cc_ex1())]
#[case(ser_cc_ex2())]
#[case(ser_cc_ex3())]
fn test_serialize_custom_const<CC: CustomConst + serde::Serialize + 'static + Sized>(#[case] example : SerializeCustomConstExample<CC>) {
let expected_yaml: serde_yaml::Value = [("c".into(), example.tag.into()), ("v".into(), example.yaml)].into_iter().collect::<serde_yaml::Mapping>().into();

let yaml_by_ref = serialize_custom_const(&example.cc as &CC).unwrap();
assert_eq!(expected_yaml, yaml_by_ref);

let yaml_by_dyn_ref = serialize_custom_const(&example.cc as &dyn CustomConst).unwrap();
assert_eq!(expected_yaml, yaml_by_dyn_ref);
}

#[test]
fn t1() {
fn custom_serialized_from_into_custom_const() {
let const_int = ConstInt::new_s(4, 1).unwrap();

// verify we have a CustomSerialized:
let cs: CustomSerialized = CustomSerialized::from_custom_const_ref(&const_int).unwrap();
let cs: CustomSerialized = CustomSerialized::try_from_custom_const_ref(&const_int).unwrap();

assert_eq!(const_int.get_type(), cs.get_type());
assert_eq!(const_int.extension_reqs(), cs.extension_reqs());
assert_eq!(&serialize_custom_const(&const_int).unwrap(), cs.value());

let deser_const_int: ConstInt = deserialize_custom_const(cs.value().clone()).unwrap();
let deser_const_int: ConstInt = cs.try_into_custom_const().unwrap();

assert_eq!(const_int, deser_const_int);
}

#[test]
fn custom_serialised_try_from() {
fn custom_serialized_from_into_custom_serialised() {
let const_int = ConstInt::new_s(4, 1).unwrap();
let cs0: CustomSerialized = CustomSerialized::try_from_custom_const_ref(&const_int).unwrap();

let cs1 = CustomSerialized::try_from_custom_const_ref(&cs0).unwrap();
assert_eq!(&cs0, &cs1);

let deser_const_int: ConstInt = cs0.try_into_custom_const().unwrap();
assert_eq!(&const_int, &deser_const_int);
}

#[test]
fn custom_serialized_try_from_dyn_custom_const() {
let const_int = ConstInt::new_s(4, 1).unwrap();
let cs: CustomSerialized = const_int.clone_box().try_into().unwrap();
assert_eq!(const_int.get_type(), cs.get_type());
assert_eq!(const_int.extension_reqs(), cs.extension_reqs());

assert_eq!(&serialize_custom_const(&const_int).unwrap(), cs.value());

let deser_const_int: ConstInt = {
let dyn_box: Box<dyn CustomConst> = cs.try_into().unwrap();
*dyn_box.downcast().unwrap()
};
assert_eq!(const_int, deser_const_int)
}

#[test]
fn t2() {
fn nested_custom_serialized() {
let const_int = ConstInt::new_s(4, 1).unwrap();
let cs_inner: CustomSerialized = const_int.clone_box().try_into().unwrap();
let cs_inner: CustomSerialized = CustomSerialized::try_from_custom_const_ref(&const_int).unwrap();

let cs_inner_ser = serialize_custom_const(&cs_inner).unwrap();

// cs_outer is a CustomSerialized of a CustomSerialized of a ConstInt
Expand Down

0 comments on commit a4bfc32

Please sign in to comment.