diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index 5b6fbb4657..07e940fb1d 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -1,7 +1,7 @@ import inspect import sys from abc import ABC -from typing import Any, Literal +from typing import Any, Literal, Optional from pydantic import Field, RootModel @@ -505,18 +505,20 @@ class Config: # -------------------------------------- -class OpDef(BaseOp, populate_by_name=True): +class FixedHugr(ConfiguredBaseModel): + extensions: ExtensionSet + hugr: Any + + +class OpDef(ConfiguredBaseModel, populate_by_name=True): """Serializable definition for dynamically loaded operations.""" + extension: ExtensionId name: str # Unique identifier of the operation. description: str # Human readable description of the operation. - inputs: list[tuple[str | None, Type]] - outputs: list[tuple[str | None, Type]] - misc: dict[str, Any] # Miscellaneous data associated with the operation. - def_: str | None = Field( - ..., alias="def" - ) # (YAML?)-encoded definition of the operation. - extension_reqs: ExtensionSet # Resources required to execute this operation. + misc: Optional[dict[str, Any]] = None + signature: Optional[PolyFuncType] = None + lower_funcs: list[FixedHugr] # Now that all classes are defined, we need to update the ForwardRefs in all type diff --git a/hugr-py/src/hugr/serialization/testing_hugr.py b/hugr-py/src/hugr/serialization/testing_hugr.py index 59db4b80de..98904aaf6c 100644 --- a/hugr-py/src/hugr/serialization/testing_hugr.py +++ b/hugr-py/src/hugr/serialization/testing_hugr.py @@ -1,7 +1,7 @@ from pydantic import ConfigDict from typing import Literal from .tys import Type, SumType, PolyFuncType, ConfiguredBaseModel, model_rebuild -from .ops import Value, OpType, classes as ops_classes +from .ops import Value, OpType, OpDef, classes as ops_classes class TestingHugr(ConfiguredBaseModel): @@ -14,6 +14,7 @@ class TestingHugr(ConfiguredBaseModel): poly_func_type: PolyFuncType | None = None value: Value | None = None optype: OpType | None = None + op_def: OpDef | None = None @classmethod def get_version(cls) -> str: diff --git a/hugr-py/src/hugr/serialization/tys.py b/hugr-py/src/hugr/serialization/tys.py index 165315fdd9..c66b78bb14 100644 --- a/hugr-py/src/hugr/serialization/tys.py +++ b/hugr-py/src/hugr/serialization/tys.py @@ -238,9 +238,6 @@ class Config: "a set of required resources." ) } - json_schema_extra = { - "required": ["t", "input", "output"], - } class PolyFuncType(ConfiguredBaseModel): diff --git a/hugr/src/extension.rs b/hugr/src/extension.rs index ce58221300..c8caaf251d 100644 --- a/hugr/src/extension.rs +++ b/hugr/src/extension.rs @@ -28,7 +28,7 @@ pub use infer::{ExtensionSolution, InferExtensionError}; mod op_def; pub use op_def::{ - CustomSignatureFunc, CustomValidator, OpDef, SignatureFromArgs, SignatureFunc, + CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc, ValidateJustArgs, ValidateTypeArgs, }; mod type_def; @@ -552,7 +552,8 @@ impl FromIterator for ExtensionSet { } #[cfg(test)] -mod test { +pub mod test { + pub use super::op_def::test::SimpleOpDef; #[cfg(feature = "proptest")] mod proptest { diff --git a/hugr/src/extension/op_def.rs b/hugr/src/extension/op_def.rs index 1e68400ea4..3ab6b89a74 100644 --- a/hugr/src/extension/op_def.rs +++ b/hugr/src/extension/op_def.rs @@ -122,7 +122,7 @@ pub struct CustomValidator { #[serde(flatten)] poly_func: PolyFuncType, #[serde(skip)] - validate: Box, + pub(crate) validate: Box, } impl CustomValidator { @@ -265,11 +265,17 @@ impl Debug for SignatureFunc { /// Different ways that an [OpDef] can lower operation nodes i.e. provide a Hugr /// that implements the operation using a set of other extensions. #[derive(serde::Deserialize, serde::Serialize)] +#[serde(untagged)] pub enum LowerFunc { /// Lowering to a fixed Hugr. Since this cannot depend upon the [TypeArg]s, /// this will generally only be applicable if the [OpDef] has no [TypeParam]s. - #[serde(rename = "hugr")] - FixedHugr(ExtensionSet, Hugr), + FixedHugr { + /// The extensions required by the [`Hugr`] + extensions: ExtensionSet, + /// The [`Hugr`] to be used to replace [`CustomOp`]s matching the parent + /// [`OpDef`] + hugr: Hugr, + }, /// Custom binary function that can (fallibly) compute a Hugr /// for the particular instance and set of available extensions. #[serde(skip)] @@ -279,7 +285,7 @@ pub enum LowerFunc { impl Debug for LowerFunc { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::FixedHugr(_, _) => write!(f, "FixedHugr"), + Self::FixedHugr { .. } => write!(f, "FixedHugr"), Self::CustomFunc(_) => write!(f, ""), } } @@ -305,8 +311,7 @@ pub struct OpDef { signature_func: SignatureFunc, // Some operations cannot lower themselves and tools that do not understand them // can only treat them as opaque/black-box ops. - #[serde(flatten)] - lower_funcs: Vec, + pub(crate) lower_funcs: Vec, /// Operations can optionally implement [`ConstFold`] to implement constant folding. #[serde(skip)] @@ -360,9 +365,9 @@ impl OpDef { self.lower_funcs .iter() .flat_map(|f| match f { - LowerFunc::FixedHugr(req_res, h) => { - if available_extensions.is_superset(req_res) { - Some(h.clone()) + LowerFunc::FixedHugr { extensions, hugr } => { + if available_extensions.is_superset(extensions) { + Some(hugr.clone()) } else { None } @@ -464,12 +469,14 @@ impl Extension { } #[cfg(test)] -mod test { +pub mod test { use std::num::NonZeroU64; + use itertools::Itertools; + use super::SignatureFromArgs; use crate::builder::{DFGBuilder, Dataflow, DataflowHugr}; - use crate::extension::op_def::LowerFunc; + use crate::extension::op_def::{CustomValidator, LowerFunc, OpDef, SignatureFunc}; use crate::extension::prelude::USIZE_T; use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE}; use crate::extension::{SignatureError, EMPTY_REG, PRELUDE_REGISTRY}; @@ -484,6 +491,65 @@ mod test { const EXT_ID: ExtensionId = "MyExt"; } + #[derive(serde::Serialize, serde::Deserialize, Debug)] + pub struct SimpleOpDef(OpDef); + + impl PartialEq for SimpleOpDef { + fn eq(&self, other: &Self) -> bool { + let OpDef { + extension, + name, + description, + misc, + signature_func, + lower_funcs, + constant_folder, + } = &self.0; + let OpDef { + extension: other_extension, + name: other_name, + description: other_description, + misc: other_misc, + signature_func: other_signature_func, + lower_funcs: other_lower_funcs, + constant_folder: other_constant_folder, + } = &other.0; + + let get_sig = |sf: &_| match sf { + // if SignatureFunc or CustomValidator are changed we should get + // an error here, update do validate the parts of the heirarchy that + // are changed. + SignatureFunc::TypeScheme(CustomValidator { + poly_func, + validate: _, + }) => Some(poly_func.clone()), + SignatureFunc::CustomFunc(_) => None, + }; + + let get_lower_funcs = |lfs: &Vec| { + lfs.iter() + .map(|lf| match lf { + // as with get_sig above, this should break if the heirarchy + // is changed, update similarly. + LowerFunc::FixedHugr { extensions, hugr } => { + Some((extensions.clone(), hugr.clone())) + } + LowerFunc::CustomFunc(_) => None, + }) + .collect_vec() + }; + + extension == other_extension + && name == other_name + && description == other_description + && misc == other_misc + && get_sig(signature_func) == get_sig(other_signature_func) + && get_lower_funcs(lower_funcs) == get_lower_funcs(other_lower_funcs) + && constant_folder.is_none() + && other_constant_folder.is_none() + } + } + #[test] fn op_def_with_type_scheme() -> Result<(), Box> { let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap(); @@ -495,7 +561,10 @@ mod test { let type_scheme = PolyFuncType::new(vec![TP], FunctionType::new_endo(vec![list_of_var])); let def = e.add_op(OP_NAME, "desc".into(), type_scheme)?; - def.add_lower_func(LowerFunc::FixedHugr(ExtensionSet::new(), Hugr::default())); + def.add_lower_func(LowerFunc::FixedHugr { + extensions: ExtensionSet::new(), + hugr: Hugr::default(), + }); def.add_misc("key", Default::default()); assert_eq!(def.description(), "desc"); assert_eq!(def.lower_funcs.len(), 1); @@ -662,4 +731,82 @@ mod test { ); Ok(()) } + + #[cfg(feature = "proptest")] + mod proptest { + use super::SimpleOpDef; + use ::proptest::prelude::*; + + use crate::{ + builder::test::simple_dfg_hugr, + extension::{ + op_def::LowerFunc, CustomValidator, ExtensionId, ExtensionSet, OpDef, SignatureFunc, + }, + types::PolyFuncType, + }; + + impl Arbitrary for SignatureFunc { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + // TODO there is also SignatureFunc::CustomFunc, but for now + // this is not serialised. When it is, we should generate + // examples here . + any::() + .prop_map(|x| SignatureFunc::TypeScheme(CustomValidator::from_polyfunc(x))) + .boxed() + } + } + + impl Arbitrary for LowerFunc { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + // TODO There is also LowerFunc::CustomFunc, but for now this is + // not serialised. When it is, we should generate examples here. + any::() + .prop_map(|extensions| LowerFunc::FixedHugr { + extensions, + hugr: simple_dfg_hugr(), + }) + .boxed() + } + } + + impl Arbitrary for SimpleOpDef { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + use crate::proptest::{any_serde_yaml_value, any_smolstr, any_string}; + use proptest::collection::{hash_map, vec}; + let signature_func: BoxedStrategy = any::(); + let lower_funcs: BoxedStrategy = any::(); + let misc = hash_map(any_string(), any_serde_yaml_value(), 0..3); + ( + any::(), + any_smolstr(), + any_string(), + misc, + signature_func, + vec(lower_funcs, 0..2), + ) + .prop_map( + |(extension, name, description, misc, signature_func, lower_funcs)| { + Self(OpDef { + extension, + name, + description, + misc, + signature_func, + lower_funcs, + // TODO ``constant_folder` is not serialised, we should + // generate examples once it is. + constant_folder: None, + }) + }, + ) + .boxed() + } + } + } } diff --git a/hugr/src/hugr/serialize/test.rs b/hugr/src/hugr/serialize/test.rs index 1ae401ce26..31fc953622 100644 --- a/hugr/src/hugr/serialize/test.rs +++ b/hugr/src/hugr/serialize/test.rs @@ -5,7 +5,7 @@ use crate::builder::{ }; use crate::extension::prelude::{BOOL_T, USIZE_T}; use crate::extension::simple_op::MakeRegisteredOp; -use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY}; +use crate::extension::{test::SimpleOpDef, EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{dataflow::IOTrait, Input, Module, Noop, Output, DFG}; @@ -34,6 +34,7 @@ struct SerTestingV1 { poly_func_type: Option, value: Option, optype: Option, + op_def: Option, } type TestingModel = SerTestingV1; @@ -88,6 +89,7 @@ impl_sertesting_from!(crate::types::SumType, sum_type); impl_sertesting_from!(crate::types::PolyFuncType, poly_func_type); impl_sertesting_from!(crate::ops::Value, value); impl_sertesting_from!(NodeSer, optype); +impl_sertesting_from!(SimpleOpDef, op_def); #[test] fn empty_hugr_serialize() { @@ -377,8 +379,8 @@ fn serialize_types_roundtrip() { #[cfg(feature = "proptest")] mod proptest { - use super::super::NodeSer; use super::check_testing_roundtrip; + use super::{NodeSer, SimpleOpDef}; use crate::extension::ExtensionSet; use crate::ops::{OpType, Value}; use crate::types::{PolyFuncType, Type}; @@ -419,8 +421,13 @@ mod proptest { } #[test] - fn prop_roundtrip_optype(ns: NodeSer) { - check_testing_roundtrip(ns) + fn prop_roundtrip_optype(op: NodeSer ) { + check_testing_roundtrip(op) + } + + #[test] + fn prop_roundtrip_opdef(opdef: SimpleOpDef) { + check_testing_roundtrip(opdef) } } } diff --git a/hugr/src/ops/constant/custom.rs b/hugr/src/ops/constant/custom.rs index 9a468b7723..c98b1943f5 100644 --- a/hugr/src/ops/constant/custom.rs +++ b/hugr/src/ops/constant/custom.rs @@ -393,4 +393,38 @@ mod test { inner_deser.downcast_ref::().unwrap() ); } + + #[cfg(feature = "proptest")] + mod proptest { + use ::proptest::prelude::*; + + use crate::{ + extension::ExtensionSet, + ops::constant::CustomSerialized, + proptest::{any_serde_yaml_value, any_string}, + types::Type, + }; + + impl Arbitrary for CustomSerialized { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + let typ = any::(); + let extensions = any::(); + let value = (any_serde_yaml_value(), any_string()).prop_map(|(value, c)| { + [("c".into(), c.into()), ("v".into(), value)] + .into_iter() + .collect::() + .into() + }); + (typ, value, extensions) + .prop_map(|(typ, value, extensions)| CustomSerialized { + typ, + value, + extensions, + }) + .boxed() + } + } + } } diff --git a/hugr/src/proptest.rs b/hugr/src/proptest.rs index 749d498422..3a2c6d16f2 100644 --- a/hugr/src/proptest.rs +++ b/hugr/src/proptest.rs @@ -90,6 +90,10 @@ pub fn any_string() -> SBoxedStrategy { ANY_STRING.clone() } +pub fn any_smolstr() -> SBoxedStrategy { + ANY_STRING.clone().prop_map_into().sboxed() +} + pub fn any_serde_yaml_value() -> impl Strategy { // use serde_yaml::value::{Tag, TaggedValue, Value}; ANY_SERDE_YAML_VALUE_LEAF diff --git a/specification/schema/testing_hugr_schema_strict_v1.json b/specification/schema/testing_hugr_schema_strict_v1.json index 850d2674d2..0202bc2821 100644 --- a/specification/schema/testing_hugr_schema_strict_v1.json +++ b/specification/schema/testing_hugr_schema_strict_v1.json @@ -817,6 +817,27 @@ "title": "ExtensionsParam", "type": "object" }, + "FixedHugr": { + "additionalProperties": false, + "properties": { + "extensions": { + "items": { + "type": "string" + }, + "title": "Extensions", + "type": "array" + }, + "hugr": { + "title": "Hugr" + } + }, + "required": [ + "extensions", + "hugr" + ], + "title": "FixedHugr", + "type": "object" + }, "FuncDecl": { "additionalProperties": false, "description": "External function declaration, linked at runtime.", @@ -1363,6 +1384,62 @@ "title": "Noop", "type": "object" }, + "OpDef": { + "additionalProperties": false, + "description": "Serializable definition for dynamically loaded operations.", + "properties": { + "extension": { + "title": "Extension", + "type": "string" + }, + "name": { + "title": "Name", + "type": "string" + }, + "description": { + "title": "Description", + "type": "string" + }, + "misc": { + "anyOf": [ + { + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Misc" + }, + "signature": { + "anyOf": [ + { + "$ref": "#/$defs/PolyFuncType" + }, + { + "type": "null" + } + ], + "default": null + }, + "lower_funcs": { + "items": { + "$ref": "#/$defs/FixedHugr" + }, + "title": "Lower Funcs", + "type": "array" + } + }, + "required": [ + "extension", + "name", + "description", + "lower_funcs" + ], + "title": "OpDef", + "type": "object" + }, "OpType": { "description": "A constant operation.", "discriminator": { @@ -2325,6 +2402,17 @@ } ], "default": null + }, + "op_def": { + "anyOf": [ + { + "$ref": "#/$defs/OpDef" + }, + { + "type": "null" + } + ], + "default": null } }, "title": "HugrTesting", diff --git a/specification/schema/testing_hugr_schema_v1.json b/specification/schema/testing_hugr_schema_v1.json index 40e1b894c7..40e57d4fe1 100644 --- a/specification/schema/testing_hugr_schema_v1.json +++ b/specification/schema/testing_hugr_schema_v1.json @@ -817,6 +817,27 @@ "title": "ExtensionsParam", "type": "object" }, + "FixedHugr": { + "additionalProperties": true, + "properties": { + "extensions": { + "items": { + "type": "string" + }, + "title": "Extensions", + "type": "array" + }, + "hugr": { + "title": "Hugr" + } + }, + "required": [ + "extensions", + "hugr" + ], + "title": "FixedHugr", + "type": "object" + }, "FuncDecl": { "additionalProperties": true, "description": "External function declaration, linked at runtime.", @@ -1363,6 +1384,62 @@ "title": "Noop", "type": "object" }, + "OpDef": { + "additionalProperties": true, + "description": "Serializable definition for dynamically loaded operations.", + "properties": { + "extension": { + "title": "Extension", + "type": "string" + }, + "name": { + "title": "Name", + "type": "string" + }, + "description": { + "title": "Description", + "type": "string" + }, + "misc": { + "anyOf": [ + { + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Misc" + }, + "signature": { + "anyOf": [ + { + "$ref": "#/$defs/PolyFuncType" + }, + { + "type": "null" + } + ], + "default": null + }, + "lower_funcs": { + "items": { + "$ref": "#/$defs/FixedHugr" + }, + "title": "Lower Funcs", + "type": "array" + } + }, + "required": [ + "extension", + "name", + "description", + "lower_funcs" + ], + "title": "OpDef", + "type": "object" + }, "OpType": { "description": "A constant operation.", "discriminator": { @@ -2325,6 +2402,17 @@ } ], "default": null + }, + "op_def": { + "anyOf": [ + { + "$ref": "#/$defs/OpDef" + }, + { + "type": "null" + } + ], + "default": null } }, "title": "HugrTesting",