Skip to content

Commit

Permalink
Add proptest for roundtrip testing of serialisation schema
Browse files Browse the repository at this point in the history
Note that `impl Arbitrary for OpDef` is not yet implemented
  • Loading branch information
doug-q committed May 8, 2024
1 parent 94fbb8f commit fe3450d
Show file tree
Hide file tree
Showing 26 changed files with 726 additions and 119 deletions.
9 changes: 7 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,10 @@ missing_docs = "warn"
portgraph = { version = "0.12.0" }
insta = { version = "1.34.0" }

[profile.dev.package.insta]
opt-level = 3
[profile.dev.package]
insta.opt-level = 3
rand_chacha.opt-level = 3
regex.opt-level = 3
regex-automata.opt-level = 3
regex-syntax.opt-level = 3
proptest.opt-level = 3
3 changes: 3 additions & 0 deletions hugr-py/src/hugr/serialization/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ class Config:
"a set of required resources."
)
}
json_schema_extra = {
"required": ["t", "input", "output"],
}


class PolyFuncType(ConfiguredBaseModel):
Expand Down
5 changes: 5 additions & 0 deletions hugr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ bench = false
path = "src/lib.rs"

[features]
default = []
extension_inference = []
proptest = ["dep:proptest","dep:proptest-derive","dep:regex-syntax"]

[dependencies]
portgraph = { workspace = true, features = ["serde", "petgraph"] }
Expand Down Expand Up @@ -52,6 +54,9 @@ delegate = "0.12.0"
paste = "1.0"
strum = "0.26.1"
strum_macros = "0.26.1"
proptest = { version = "1.4.0", optional = true }
proptest-derive = { version = "0.4.0", optional = true}
regex-syntax = { version = "0.8.3", optional = true}

[dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }
Expand Down
32 changes: 32 additions & 0 deletions hugr/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -550,3 +550,35 @@ impl FromIterator<ExtensionId> for ExtensionSet {
Self(BTreeSet::from_iter(iter))
}
}

#[cfg(test)]
mod test {

#[cfg(feature = "proptest")]
mod proptest {

use ::proptest::{collection::hash_set, prelude::*};

use super::super::{ExtensionId, ExtensionSet};
impl Arbitrary for ExtensionSet {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
let vars = hash_set(0..10usize, 0..3);
let extensions = hash_set(any::<ExtensionId>(), 0..3);
(vars, extensions)
.prop_map(|(vars, extensions)| {
let mut r = Self::new();
for v in vars {
r.insert_type_var(v);
}
for e in extensions {
r.insert(&e)
}
r
})
.boxed()
}
}
}
}
2 changes: 1 addition & 1 deletion hugr/src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
pub mod hugrmut;

mod ident;
pub(crate) mod ident;
pub mod rewrite;
pub mod serialize;
pub mod validate;
Expand Down
41 changes: 40 additions & 1 deletion hugr/src/hugr/ident.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ use regex::Regex;
use smol_str::SmolStr;
use thiserror::Error;

pub static PATH_COMPONENT_REGEX_STR: &str = r"[\w--\d]\w*";
#[cfg(all(test, feature = "proptest"))]
pub static PATH_COMPONENT_NICE_REGEX_STR: &str = r"[[:alpha:]][[[:alpha:]]0-9]*";
lazy_static! {
pub static ref PATH_REGEX: Regex = Regex::new(r"^[\w--\d]\w*(\.[\w--\d]\w*)*$").unwrap();
pub static ref PATH_REGEX: Regex =
Regex::new(&format!(r"^{0}(\.{0})*$", PATH_COMPONENT_REGEX_STR)).unwrap();
}

#[derive(
Expand All @@ -23,6 +27,7 @@ lazy_static! {
serde::Deserialize,
)]
/// A non-empty dot-separated list of valid identifiers
pub struct IdentList(SmolStr);

impl IdentList {
Expand Down Expand Up @@ -75,6 +80,40 @@ pub struct InvalidIdentifier(SmolStr);

#[cfg(test)]
mod test {

#[cfg(feature = "proptest")]
mod proptest {
use crate::hugr::ident::IdentList;
use ::proptest::prelude::*;
impl Arbitrary for super::IdentList {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
use crate::proptest::any_ident_string;
use proptest::collection::vec;
// we shrink to more readable (i.e. :alpha:) names
vec(any_ident_string(), 1..2)
.prop_map(|vs| {
IdentList::new(
itertools::intersperse(
vs.into_iter().map(Into::<String>::into),
".".into(),
)
.collect::<String>(),
)
.unwrap()
})
.boxed()
}
}
proptest! {
#[test]
fn arbitrary_identlist_valid((IdentList(ident_list)): IdentList) {
assert!(IdentList::new(ident_list).is_ok())
}
}
}

use super::IdentList;

#[test]
Expand Down
30 changes: 0 additions & 30 deletions hugr/src/hugr/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,36 +71,6 @@ struct SerHugrV1 {
encoder: Option<String>,
}

/// Version 1 of the Testing HUGR serialisation format, see `testing_hugr.py`.
#[cfg(test)]
#[derive(Serialize, Deserialize, PartialEq, Debug, Default)]
struct SerTestingV1 {
typ: Option<crate::types::Type>,
sum_type: Option<crate::types::SumType>,
poly_func_type: Option<crate::types::PolyFuncType>,
value: Option<crate::ops::Value>,
optype: Option<NodeSer>,
}

macro_rules! impl_sertesting_from {
($typ:ty, $field:ident) => {
#[cfg(test)]
impl From<$typ> for SerTestingV1 {
fn from(v: $typ) -> Self {
let mut r: Self = Default::default();
r.$field = Some(v);
r
}
}
};
}

impl_sertesting_from!(crate::types::Type, typ);
impl_sertesting_from!(crate::types::SumType, sum_type);
impl_sertesting_from!(crate::types::PolyFuncType, poly_func_type);
impl_sertesting_from!(crate::ops::Value, value);
impl_sertesting_from!(NodeSer, optype);

/// Errors that can occur while serializing a HUGR.
#[derive(Debug, Clone, PartialEq, Error)]
#[non_exhaustive]
Expand Down
140 changes: 62 additions & 78 deletions hugr/src/hugr/serialize/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,39 @@ use crate::builder::{
test::closed_dfg_root_hugr, Container, DFGBuilder, Dataflow, DataflowHugr,
DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::{BOOL_T, PRELUDE_ID, QB_T, USIZE_T};
use crate::extension::prelude::{BOOL_T, USIZE_T};
use crate::extension::simple_op::MakeRegisteredOp;
use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::{self, Value};
use crate::ops::{dataflow::IOTrait, Input, Module, Noop, Output, DFG};
use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::std_extensions::arithmetic::int_types::{int_custom_type, ConstInt, INT_TYPES};

use crate::std_extensions::logic::NotOp;
use crate::types::type_param::{TypeArg, TypeParam};
use crate::types::{FunctionType, PolyFuncType, SumType, Type, TypeBound};

use crate::types::{FunctionType, Type};
use crate::{type_row, OutgoingPort};
use itertools::Itertools;
use jsonschema::{Draft, JSONSchema};
use lazy_static::lazy_static;
use portgraph::LinkView;
use portgraph::{multiportgraph::MultiPortGraph, Hierarchy, LinkMut, PortMut, UnmanagedDenseMap};
use rstest::rstest;

const NAT: Type = crate::extension::prelude::USIZE_T;
const QB: Type = crate::extension::prelude::QB_T;

/// Version 1 of the Testing HUGR serialisation format, see `testing_hugr.py`.
#[cfg(test)]
#[derive(Serialize, Deserialize, PartialEq, Debug, Default)]
struct SerTestingV1 {
typ: Option<crate::types::Type>,
sum_type: Option<crate::types::SumType>,
poly_func_type: Option<crate::types::PolyFuncType>,
value: Option<crate::ops::Value>,
optype: Option<NodeSer>,
}

type TestingModel = SerTestingV1;

macro_rules! include_schema {
Expand Down Expand Up @@ -61,6 +70,25 @@ include_schema!(
"../../../../specification/schema/testing_hugr_schema_strict_v1.json"
);

macro_rules! impl_sertesting_from {
($typ:ty, $field:ident) => {
#[cfg(test)]
impl From<$typ> for TestingModel {
fn from(v: $typ) -> Self {
let mut r: Self = Default::default();
r.$field = Some(v);
r
}
}
};
}

impl_sertesting_from!(crate::types::Type, typ);
impl_sertesting_from!(crate::types::SumType, sum_type);
impl_sertesting_from!(crate::types::PolyFuncType, poly_func_type);
impl_sertesting_from!(crate::ops::Value, value);
impl_sertesting_from!(NodeSer, optype);

#[test]
fn empty_hugr_serialize() {
let hg = Hugr::default();
Expand Down Expand Up @@ -149,6 +177,7 @@ pub fn check_hugr_roundtrip(hugr: &Hugr, check_schema: bool) -> Hugr {
new_hugr
}

#[allow(unused)]
fn check_testing_roundtrip(t: impl Into<TestingModel>) {
let before = Versioned::new(t.into());
let after_strict = ser_roundtrip_validate(&before, Some(&TESTING_SCHEMA_STRICT));
Expand Down Expand Up @@ -346,79 +375,34 @@ fn serialize_types_roundtrip() {
assert_eq!(ser_roundtrip(&t), t);
}

#[rstest]
#[case(BOOL_T)]
#[case(USIZE_T)]
#[case(INT_TYPES[2].clone())]
#[case(Type::new_alias(crate::ops::AliasDecl::new("t", TypeBound::Any)))]
#[case(Type::new_var_use(2, TypeBound::Copyable))]
#[case(Type::new_tuple(type_row![BOOL_T,QB_T]))]
#[case(Type::new_sum([type_row![BOOL_T,QB_T], type_row![Type::new_unit_sum(4)]]))]
#[case(Type::new_function(FunctionType::new_endo(type_row![QB_T,BOOL_T,USIZE_T])))]
fn roundtrip_type(#[case] typ: Type) {
check_testing_roundtrip(typ);
}

#[rstest]
#[case(SumType::new_unary(2))]
#[case(SumType::new([type_row![USIZE_T, QB_T], type_row![]]))]
fn roundtrip_sumtype(#[case] sum_type: SumType) {
check_testing_roundtrip(sum_type);
}

#[rstest]
#[case(Value::unit())]
#[case(Value::true_val())]
#[case(Value::unit_sum(3,5).unwrap())]
#[case(Value::extension(ConstF64::new(-1.5)))]
#[case(Value::extension(ConstF64::new(0.0)))]
#[case(Value::extension(ConstF64::new(-0.0)))]
// These cases fail
// #[case(Value::extension(ConstF64::new(std::f64::NAN)))]
// #[case(Value::extension(ConstF64::new(std::f64::INFINITY)))]
// #[case(Value::extension(ConstF64::new(std::f64::NEG_INFINITY)))]
#[case(Value::extension(ConstF64::new(std::f64::MIN_POSITIVE)))]
#[case(Value::sum(1,[Value::extension(ConstInt::new_u(2,1).unwrap())], SumType::new([vec![], vec![INT_TYPES[2].clone()]])).unwrap())]
#[case(Value::tuple([Value::false_val(), Value::extension(ConstInt::new_s(2,1).unwrap())]))]
#[case(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap())]
fn roundtrip_value(#[case] value: Value) {
check_testing_roundtrip(value);
}
#[cfg(feature = "proptest")]
mod proptest {
use super::super::NodeSer;
use super::check_testing_roundtrip;
use crate::extension::ExtensionSet;
use crate::ops::{OpType, Value};
use crate::types::{PolyFuncType, Type};
use proptest::prelude::*;

proptest! {
#[test]
fn prop_roundtrip_type(t: Type) {
check_testing_roundtrip(t)
}

fn polyfunctype1() -> PolyFuncType {
let mut extension_set = ExtensionSet::new();
extension_set.insert_type_var(1);
let function_type = FunctionType::new_endo(type_row![]).with_extension_delta(extension_set);
PolyFuncType::new([TypeParam::max_nat(), TypeParam::Extensions], function_type)
}
#[test]
fn prop_roundtrip_poly_func_type(t: PolyFuncType) {
check_testing_roundtrip(t)
}

#[rstest]
#[case(FunctionType::new_endo(type_row![]).into())]
#[case(polyfunctype1())]
#[case(PolyFuncType::new([TypeParam::Opaque { ty: int_custom_type(TypeArg::BoundedNat { n: 1 }) }], FunctionType::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))]
#[case(PolyFuncType::new([TypeBound::Eq.into()], FunctionType::new_endo(type_row![Type::new_var_use(0, TypeBound::Eq)])))]
#[case(PolyFuncType::new([TypeParam::List { param: Box::new(TypeBound::Any.into()) }], FunctionType::new_endo(type_row![])))]
#[case(PolyFuncType::new([TypeParam::Tuple { params: [TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], FunctionType::new_endo(type_row![])))]
fn roundtrip_polyfunctype(#[case] poly_func_type: PolyFuncType) {
check_testing_roundtrip(poly_func_type)
}
#[test]
fn prop_roundtrip_value(t: Value) {
check_testing_roundtrip(t)
}

#[rstest]
#[case(ops::Module)]
#[case(ops::FuncDefn { name: "polyfunc1".into(), signature: polyfunctype1()})]
#[case(ops::FuncDecl { name: "polyfunc2".into(), signature: polyfunctype1()})]
#[case(ops::AliasDefn { name: "aliasdefn".into(), definition: Type::new_unit_sum(4)})]
#[case(ops::AliasDecl { name: "aliasdecl".into(), bound: TypeBound::Any})]
#[case(ops::Const::new(Value::false_val()))]
#[case(ops::Const::new(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap()))]
#[case(ops::Input::new(type_row![Type::new_var_use(3,TypeBound::Eq)]))]
#[case(ops::Output::new(vec![Type::new_function(FunctionType::new_endo(type_row![]))]))]
#[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat{n: 1}, TypeArg::Extensions{ es: ExtensionSet::singleton(&PRELUDE_ID)} ], &EMPTY_REG).unwrap())]
#[case(ops::CallIndirect { signature : FunctionType::new_endo(type_row![BOOL_T]) })]
fn roundtrip_optype(#[case] optype: impl Into<OpType> + std::fmt::Debug) {
check_testing_roundtrip(NodeSer {
parent: portgraph::NodeIndex::new(0).into(),
input_extensions: None,
op: optype.into(),
});
#[test]
fn prop_roundtrip_optype(op in ((0..(std::u32::MAX / 2) as usize). prop_map(|x| portgraph::NodeIndex::new(x).into()), any::<Option<ExtensionSet>>(), any::<OpType>()).prop_map(|(parent, input_extensions, op)| NodeSer { parent, input_extensions, op })) {
check_testing_roundtrip(op)
}
}
}
3 changes: 3 additions & 0 deletions hugr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,6 @@ pub use crate::core::{
};
pub use crate::extension::Extension;
pub use crate::hugr::{Hugr, HugrView, SimpleReplacement};

#[cfg(all(feature = "proptest", test))]
pub mod proptest;
1 change: 1 addition & 0 deletions hugr/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub use tag::OpTag;

#[enum_dispatch(OpTrait, NamedOp, ValidateOp, OpParent)]
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(all(test, feature = "proptest"), derive(proptest_derive::Arbitrary))]
/// The concrete operation types for a node in the HUGR.
// TODO: Link the NodeHandles to the OpType.
#[non_exhaustive]
Expand Down
Loading

0 comments on commit fe3450d

Please sign in to comment.