Skip to content

Commit

Permalink
format + lints
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed May 7, 2024
1 parent 2ee4dff commit 73b700b
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 49 deletions.
1 change: 1 addition & 0 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ class FixedHugr(ConfiguredBaseModel):
extensions: ExtensionSet
hugr: Any


class OpDef(ConfiguredBaseModel, populate_by_name=True):
"""Serializable definition for dynamically loaded operations."""

Expand Down
39 changes: 30 additions & 9 deletions hugr/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,11 @@ impl Debug for SignatureFunc {
pub enum LowerFunc {
/// Lowering to a fixed Hugr. Since this cannot depend upon the [TypeArg]s,
/// this will generally only be applicable if the [OpDef] has no [TypeParam]s.
FixedHugr{ extensions: ExtensionSet, hugr: Hugr},
#[allow(missing_docs)]
FixedHugr {
extensions: ExtensionSet,
hugr: Hugr,
},
/// Custom binary function that can (fallibly) compute a Hugr
/// for the particular instance and set of available extensions.
#[serde(skip)]
Expand All @@ -279,7 +283,7 @@ pub enum LowerFunc {
impl Debug for LowerFunc {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::FixedHugr{..} => write!(f, "FixedHugr"),
Self::FixedHugr { .. } => write!(f, "FixedHugr"),
Self::CustomFunc(_) => write!(f, "<custom lower>"),
}
}
Expand Down Expand Up @@ -359,7 +363,7 @@ impl OpDef {
self.lower_funcs
.iter()
.flat_map(|f| match f {
LowerFunc::FixedHugr{extensions, hugr} => {
LowerFunc::FixedHugr { extensions, hugr } => {
if available_extensions.is_superset(extensions) {
Some(hugr.clone())
} else {
Expand Down Expand Up @@ -469,29 +473,42 @@ mod test {
mod proptest {
use ::proptest::prelude::*;

use crate::{builder::test::simple_dfg_hugr, extension::{op_def::LowerFunc, CustomValidator, ExtensionId, ExtensionSet, OpDef, SignatureFunc}, types::PolyFuncType};
use crate::{
builder::test::simple_dfg_hugr,
extension::{
op_def::LowerFunc, CustomValidator, ExtensionId, ExtensionSet, OpDef, SignatureFunc,
},
types::PolyFuncType,
};

impl Arbitrary for SignatureFunc {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
any::<PolyFuncType>().prop_map(|x| SignatureFunc::TypeScheme(CustomValidator::from_polyfunc(x))).boxed()
any::<PolyFuncType>()
.prop_map(|x| SignatureFunc::TypeScheme(CustomValidator::from_polyfunc(x)))
.boxed()
}
}

impl Arbitrary for LowerFunc {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
any::<ExtensionSet>().prop_map(|extensions| LowerFunc::FixedHugr{extensions, hugr: simple_dfg_hugr()}).boxed()
any::<ExtensionSet>()
.prop_map(|extensions| LowerFunc::FixedHugr {
extensions,
hugr: simple_dfg_hugr(),
})
.boxed()
}
}

impl Arbitrary for OpDef {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
use crate::proptest::{any_serde_yaml_value, any_string, any_smolstr};
use crate::proptest::{any_serde_yaml_value, any_smolstr, any_string};
use proptest::collection::{hash_map, vec};
let signature_func: BoxedStrategy<SignatureFunc> = any::<SignatureFunc>();
let lower_funcs: BoxedStrategy<LowerFunc> = any::<LowerFunc>();
Expand All @@ -514,7 +531,8 @@ mod test {
lower_funcs,
constant_folder: None,
},
).boxed()
)
.boxed()
}
}
}
Expand Down Expand Up @@ -549,7 +567,10 @@ mod test {
let type_scheme = PolyFuncType::new(vec![TP], FunctionType::new_endo(vec![list_of_var]));

let def = e.add_op(OP_NAME, "desc".into(), type_scheme)?;
def.add_lower_func(LowerFunc::FixedHugr{extensions: ExtensionSet::new(), hugr: Hugr::default()});
def.add_lower_func(LowerFunc::FixedHugr {
extensions: ExtensionSet::new(),
hugr: Hugr::default(),
});
def.add_misc("key", Default::default());
assert_eq!(def.description(), "desc");
assert_eq!(def.lower_funcs.len(), 1);
Expand Down
4 changes: 3 additions & 1 deletion hugr/src/hugr/serialize/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ impl PartialEq for SimpleOpDef {
let get_lower_funcs = |lfs: &Vec<LowerFunc>| {
lfs.iter()
.map(|lf| match lf {
LowerFunc::FixedHugr{extensions, hugr} => Some((extensions.clone(), hugr.clone())),
LowerFunc::FixedHugr { extensions, hugr } => {
Some((extensions.clone(), hugr.clone()))
}
_ => None,
})
.collect::<Option<Vec<(ExtensionSet, Hugr)>>>()
Expand Down
10 changes: 5 additions & 5 deletions hugr/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
mod custom;

use self::custom::CustomSerializedError;

use super::{NamedOp, OpName, OpTrait, StaticTag};
use super::{OpTag, OpType};
use crate::extension::ExtensionSet;
Expand Down Expand Up @@ -444,7 +442,9 @@ mod test {

/// A [`CustomSerialized`] encoding a [`FLOAT64_TYPE`] float constant used in testing.
pub(crate) fn serialized_float(f: f64) -> Value {
CustomSerialized::try_from_custom_const(ConstF64::new(f)).unwrap().into()
CustomSerialized::try_from_custom_const(ConstF64::new(f))
.unwrap()
.into()
}

fn test_registry() -> ExtensionRegistry {
Expand Down Expand Up @@ -586,8 +586,8 @@ mod test {
ex_id.clone(),
TypeBound::Eq,
);
let yaml_const: Value = CustomSerialized::new
(typ_int.clone(), YamlValue::Number(6.into()), ex_id.clone())
let yaml_const: Value =
CustomSerialized::new(typ_int.clone(), YamlValue::Number(6.into()), ex_id.clone())
.into();
let classic_t = Type::new_extension(typ_int.clone());
assert_matches!(classic_t.least_upper_bound(), TypeBound::Eq);
Expand Down
89 changes: 55 additions & 34 deletions hugr/src/ops/constant/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ pub fn downcast_equal_consts<T: CustomConst + PartialEq>(
}
}


/// Serialize any CustomConst using the `impl Serialize for &dyn CustomConst`.
/// In particular this works on `&dyn CustomConst` and `Box<dyn CustomConst>::Target`.
/// See tests below
Expand All @@ -84,13 +83,11 @@ fn deserialize_custom_const<CC: CustomConst>(
) -> Result<CC, serde_yaml::Error> {
match deserialize_dyn_custom_const(value)?.downcast::<CC>() {
Ok(cc) => Ok(*cc),
Err(dyn_cc) => {
Err(<serde_yaml::Error as serde::de::Error>::custom(format!(
"Failed to deserialize [{}]: {:?}",
std::any::type_name::<CC>(),
dyn_cc
)))
}
Err(dyn_cc) => Err(<serde_yaml::Error as serde::de::Error>::custom(format!(
"Failed to deserialize [{}]: {:?}",
std::any::type_name::<CC>(),
dyn_cc
))),
}
}

Expand All @@ -100,8 +97,6 @@ fn deserialize_dyn_custom_const(
serde_yaml::from_value(value)
}



impl_downcast!(CustomConst);
impl_box_clone!(CustomConst, CustomConstBoxClone);

Expand Down Expand Up @@ -171,13 +166,14 @@ impl CustomSerialized {
}

/// TODO
pub fn try_from_custom_const_box(cc: Box<dyn CustomConst>) -> Result<Self, CustomSerializedError> {
pub fn try_from_custom_const_box(
cc: Box<dyn CustomConst>,
) -> Result<Self, CustomSerializedError> {
match cc.downcast::<Self>() {
Ok(x) => Ok(*x),
Err(cc) => {
let (typ, extension_reqs) = (cc.get_type(), cc.extension_reqs());
let value =
serialize_custom_const(cc.as_ref())
let value = serialize_custom_const(cc.as_ref())
.map_err(|err| CustomSerializedError::new_ser(err, cc))?;
Ok(Self::new(typ, value, extension_reqs))
}
Expand All @@ -189,8 +185,8 @@ impl CustomSerialized {
let (typ, extensions) = (self.get_type().clone(), self.extension_reqs());
// ideally we would not have to clone, but serde_json does not allow us
// to recover the value from the error
let cc_box = deserialize_dyn_custom_const(self.value.clone())
.unwrap_or_else(|_| Box::new(self));
let cc_box =
deserialize_dyn_custom_const(self.value.clone()).unwrap_or_else(|_| Box::new(self));
assert_eq!(cc_box.get_type(), typ);
assert_eq!(cc_box.extension_reqs(), extensions);
cc_box
Expand Down Expand Up @@ -254,25 +250,29 @@ impl From<CustomSerialized> for Box<dyn CustomConst> {
#[cfg(test)]
mod test {


use downcast_rs::Downcast;
use rstest::rstest;

use crate::{
extension::{prelude::{ConstUsize, USIZE_T}, ExtensionSet}, ops::{constant::custom::{
deserialize_custom_const, deserialize_dyn_custom_const, serialize_custom_const,
}, Value}, std_extensions::{arithmetic::int_types::ConstInt, collections::ListValue}
extension::{
prelude::{ConstUsize, USIZE_T},
ExtensionSet,
},
ops::{
constant::custom::{deserialize_dyn_custom_const, serialize_custom_const},
Value,
},
std_extensions::{arithmetic::int_types::ConstInt, collections::ListValue},
};

use super::{CustomConst, CustomConstBoxClone, CustomSerialized};

struct SerializeCustomConstExample<CC: CustomConst + serde::Serialize + 'static> {
cc: CC,
tag: &'static str,
yaml: serde_yaml::Value
yaml: serde_yaml::Value,
}

impl<CC : CustomConst + serde::Serialize + 'static> SerializeCustomConstExample<CC> {
impl<CC: CustomConst + serde::Serialize + 'static> SerializeCustomConstExample<CC> {
fn new(cc: CC, tag: &'static str) -> Self {
let yaml = serde_yaml::to_value(&cc).unwrap();
Self { cc, tag, yaml }
Expand All @@ -284,19 +284,36 @@ mod test {
}

fn ser_cc_ex2() -> SerializeCustomConstExample<ListValue> {
SerializeCustomConstExample::new(ListValue::new(USIZE_T, [ConstUsize::new(1), ConstUsize::new(2)].into_iter().map(Value::extension)), "ListValue")
SerializeCustomConstExample::new(
ListValue::new(
USIZE_T,
[ConstUsize::new(1), ConstUsize::new(2)]
.into_iter()
.map(Value::extension),
),
"ListValue",
)
}

fn ser_cc_ex3() -> SerializeCustomConstExample<CustomSerialized> {
SerializeCustomConstExample::new(CustomSerialized::new(USIZE_T, serde_yaml::Value::Null, ExtensionSet::default()), "CustomSerialized")
SerializeCustomConstExample::new(
CustomSerialized::new(USIZE_T, serde_yaml::Value::Null, ExtensionSet::default()),
"CustomSerialized",
)
}

#[rstest]
#[case(ser_cc_ex1())]
#[case(ser_cc_ex2())]
#[case(ser_cc_ex3())]
fn test_serialize_custom_const<CC: CustomConst + serde::Serialize + 'static + Sized>(#[case] example : SerializeCustomConstExample<CC>) {
let expected_yaml: serde_yaml::Value = [("c".into(), example.tag.into()), ("v".into(), example.yaml)].into_iter().collect::<serde_yaml::Mapping>().into();
fn test_serialize_custom_const<CC: CustomConst + serde::Serialize + 'static + Sized>(
#[case] example: SerializeCustomConstExample<CC>,
) {
let expected_yaml: serde_yaml::Value =
[("c".into(), example.tag.into()), ("v".into(), example.yaml)]
.into_iter()
.collect::<serde_yaml::Mapping>()
.into();

let yaml_by_ref = serialize_custom_const(&example.cc as &CC).unwrap();
assert_eq!(expected_yaml, yaml_by_ref);
Expand All @@ -323,7 +340,8 @@ mod test {
#[test]
fn custom_serialized_from_into_custom_serialised() {
let const_int = ConstInt::new_s(4, 1).unwrap();
let cs0: CustomSerialized = CustomSerialized::try_from_custom_const_ref(&const_int).unwrap();
let cs0: CustomSerialized =
CustomSerialized::try_from_custom_const_ref(&const_int).unwrap();

let cs1 = CustomSerialized::try_from_custom_const_ref(&cs0).unwrap();
assert_eq!(&cs0, &cs1);
Expand All @@ -341,7 +359,7 @@ mod test {
assert_eq!(&serialize_custom_const(&const_int).unwrap(), cs.value());

let deser_const_int: ConstInt = {
let dyn_box: Box<dyn CustomConst> = cs.try_into().unwrap();
let dyn_box: Box<dyn CustomConst> = cs.into();
*dyn_box.downcast().unwrap()
};
assert_eq!(const_int, deser_const_int)
Expand All @@ -350,7 +368,8 @@ mod test {
#[test]
fn nested_custom_serialized() {
let const_int = ConstInt::new_s(4, 1).unwrap();
let cs_inner: CustomSerialized = CustomSerialized::try_from_custom_const_ref(&const_int).unwrap();
let cs_inner: CustomSerialized =
CustomSerialized::try_from_custom_const_ref(&const_int).unwrap();

let cs_inner_ser = serialize_custom_const(&cs_inner).unwrap();

Expand Down Expand Up @@ -383,7 +402,7 @@ mod test {
use crate::{
extension::ExtensionSet,
ops::constant::CustomSerialized,
proptest::{any_serde_yaml_mapping, any_serde_yaml_value, any_string},
proptest::{any_serde_yaml_value, any_string},
types::Type,
};

Expand All @@ -393,10 +412,12 @@ mod test {
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
let typ = any::<Type>();
let extensions = any::<ExtensionSet>();
let value = (any_serde_yaml_value(), any_string())
.prop_map(|(value, c)| {
[("c".into(), c.into()), ("v".into(), value)].into_iter().collect::<serde_yaml::Mapping>().into()
});
let value = (any_serde_yaml_value(), any_string()).prop_map(|(value, c)| {
[("c".into(), c.into()), ("v".into(), value)]
.into_iter()
.collect::<serde_yaml::Mapping>()
.into()
});
(typ, value, extensions)
.prop_map(|(typ, value, extensions)| CustomSerialized {
typ,
Expand Down

0 comments on commit 73b700b

Please sign in to comment.