Skip to content

Commit

Permalink
opdef roundtrips
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed May 8, 2024
1 parent f8255b3 commit 070365b
Show file tree
Hide file tree
Showing 10 changed files with 400 additions and 31 deletions.
20 changes: 11 additions & 9 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import sys
from abc import ABC
from typing import Any, Literal
from typing import Any, Literal, Optional

from pydantic import Field, RootModel

Expand Down Expand Up @@ -505,18 +505,20 @@ class Config:
# --------------------------------------


class OpDef(BaseOp, populate_by_name=True):
class FixedHugr(ConfiguredBaseModel):
extensions: ExtensionSet
hugr: Any


class OpDef(ConfiguredBaseModel, populate_by_name=True):
"""Serializable definition for dynamically loaded operations."""

extension: ExtensionId
name: str # Unique identifier of the operation.
description: str # Human readable description of the operation.
inputs: list[tuple[str | None, Type]]
outputs: list[tuple[str | None, Type]]
misc: dict[str, Any] # Miscellaneous data associated with the operation.
def_: str | None = Field(
..., alias="def"
) # (YAML?)-encoded definition of the operation.
extension_reqs: ExtensionSet # Resources required to execute this operation.
misc: Optional[dict[str, Any]] = None
signature: Optional[PolyFuncType] = None
lower_funcs: list[FixedHugr]


# Now that all classes are defined, we need to update the ForwardRefs in all type
Expand Down
3 changes: 2 additions & 1 deletion hugr-py/src/hugr/serialization/testing_hugr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pydantic import ConfigDict
from typing import Literal
from .tys import Type, SumType, PolyFuncType, ConfiguredBaseModel, model_rebuild
from .ops import Value, OpType, classes as ops_classes
from .ops import Value, OpType, OpDef, classes as ops_classes


class TestingHugr(ConfiguredBaseModel):
Expand All @@ -14,6 +14,7 @@ class TestingHugr(ConfiguredBaseModel):
poly_func_type: PolyFuncType | None = None
value: Value | None = None
optype: OpType | None = None
op_def: OpDef | None = None

@classmethod
def get_version(cls) -> str:
Expand Down
3 changes: 0 additions & 3 deletions hugr-py/src/hugr/serialization/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,6 @@ class Config:
"a set of required resources."
)
}
json_schema_extra = {
"required": ["t", "input", "output"],
}


class PolyFuncType(ConfiguredBaseModel):
Expand Down
5 changes: 3 additions & 2 deletions hugr/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub use infer::{ExtensionSolution, InferExtensionError};

mod op_def;
pub use op_def::{
CustomSignatureFunc, CustomValidator, OpDef, SignatureFromArgs, SignatureFunc,
CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc,
ValidateJustArgs, ValidateTypeArgs,
};
mod type_def;
Expand Down Expand Up @@ -552,7 +552,8 @@ impl FromIterator<ExtensionId> for ExtensionSet {
}

#[cfg(test)]
mod test {
pub mod test {
pub use super::op_def::test::SimpleOpDef;

#[cfg(feature = "proptest")]
mod proptest {
Expand Down
171 changes: 159 additions & 12 deletions hugr/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ pub struct CustomValidator {
#[serde(flatten)]
poly_func: PolyFuncType,
#[serde(skip)]
validate: Box<dyn ValidateTypeArgs>,
pub(crate) validate: Box<dyn ValidateTypeArgs>,
}

impl CustomValidator {
Expand Down Expand Up @@ -265,11 +265,17 @@ impl Debug for SignatureFunc {
/// Different ways that an [OpDef] can lower operation nodes i.e. provide a Hugr
/// that implements the operation using a set of other extensions.
#[derive(serde::Deserialize, serde::Serialize)]
#[serde(untagged)]
pub enum LowerFunc {
/// Lowering to a fixed Hugr. Since this cannot depend upon the [TypeArg]s,
/// this will generally only be applicable if the [OpDef] has no [TypeParam]s.
#[serde(rename = "hugr")]
FixedHugr(ExtensionSet, Hugr),
FixedHugr {
/// The extensions required by the [`Hugr`]
extensions: ExtensionSet,
/// The [`Hugr`] to be used to replace [`CustomOp`]s matching the parent
/// [`OpDef`]
hugr: Hugr,
},
/// Custom binary function that can (fallibly) compute a Hugr
/// for the particular instance and set of available extensions.
#[serde(skip)]
Expand All @@ -279,7 +285,7 @@ pub enum LowerFunc {
impl Debug for LowerFunc {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::FixedHugr(_, _) => write!(f, "FixedHugr"),
Self::FixedHugr { .. } => write!(f, "FixedHugr"),
Self::CustomFunc(_) => write!(f, "<custom lower>"),
}
}
Expand All @@ -305,8 +311,7 @@ pub struct OpDef {
signature_func: SignatureFunc,
// Some operations cannot lower themselves and tools that do not understand them
// can only treat them as opaque/black-box ops.
#[serde(flatten)]
lower_funcs: Vec<LowerFunc>,
pub(crate) lower_funcs: Vec<LowerFunc>,

/// Operations can optionally implement [`ConstFold`] to implement constant folding.
#[serde(skip)]
Expand Down Expand Up @@ -360,9 +365,9 @@ impl OpDef {
self.lower_funcs
.iter()
.flat_map(|f| match f {
LowerFunc::FixedHugr(req_res, h) => {
if available_extensions.is_superset(req_res) {
Some(h.clone())
LowerFunc::FixedHugr { extensions, hugr } => {
if available_extensions.is_superset(extensions) {
Some(hugr.clone())
} else {
None
}
Expand Down Expand Up @@ -464,12 +469,14 @@ impl Extension {
}

#[cfg(test)]
mod test {
pub mod test {
use std::num::NonZeroU64;

use itertools::Itertools;

use super::SignatureFromArgs;
use crate::builder::{DFGBuilder, Dataflow, DataflowHugr};
use crate::extension::op_def::LowerFunc;
use crate::extension::op_def::{CustomValidator, LowerFunc, OpDef, SignatureFunc};
use crate::extension::prelude::USIZE_T;
use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
use crate::extension::{SignatureError, EMPTY_REG, PRELUDE_REGISTRY};
Expand All @@ -484,6 +491,65 @@ mod test {
const EXT_ID: ExtensionId = "MyExt";
}

#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct SimpleOpDef(OpDef);

impl PartialEq for SimpleOpDef {
fn eq(&self, other: &Self) -> bool {
let OpDef {
extension,
name,
description,
misc,
signature_func,
lower_funcs,
constant_folder,
} = &self.0;
let OpDef {
extension: other_extension,
name: other_name,
description: other_description,
misc: other_misc,
signature_func: other_signature_func,
lower_funcs: other_lower_funcs,
constant_folder: other_constant_folder,
} = &other.0;

let get_sig = |sf: &_| match sf {
// if SignatureFunc or CustomValidator are changed we should get
// an error here, update do validate the parts of the heirarchy that
// are changed.
SignatureFunc::TypeScheme(CustomValidator {
poly_func,
validate: _,
}) => Some(poly_func.clone()),
SignatureFunc::CustomFunc(_) => None,
};

let get_lower_funcs = |lfs: &Vec<LowerFunc>| {
lfs.iter()
.map(|lf| match lf {
// as with get_sig above, this should break if the heirarchy
// is changed, update similarly.
LowerFunc::FixedHugr { extensions, hugr } => {
Some((extensions.clone(), hugr.clone()))
}
LowerFunc::CustomFunc(_) => None,
})
.collect_vec()
};

extension == other_extension
&& name == other_name
&& description == other_description
&& misc == other_misc
&& get_sig(signature_func) == get_sig(other_signature_func)
&& get_lower_funcs(lower_funcs) == get_lower_funcs(other_lower_funcs)
&& constant_folder.is_none()
&& other_constant_folder.is_none()
}
}

#[test]
fn op_def_with_type_scheme() -> Result<(), Box<dyn std::error::Error>> {
let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap();
Expand All @@ -495,7 +561,10 @@ mod test {
let type_scheme = PolyFuncType::new(vec![TP], FunctionType::new_endo(vec![list_of_var]));

let def = e.add_op(OP_NAME, "desc".into(), type_scheme)?;
def.add_lower_func(LowerFunc::FixedHugr(ExtensionSet::new(), Hugr::default()));
def.add_lower_func(LowerFunc::FixedHugr {
extensions: ExtensionSet::new(),
hugr: Hugr::default(),
});
def.add_misc("key", Default::default());
assert_eq!(def.description(), "desc");
assert_eq!(def.lower_funcs.len(), 1);
Expand Down Expand Up @@ -662,4 +731,82 @@ mod test {
);
Ok(())
}

#[cfg(feature = "proptest")]
mod proptest {
use super::SimpleOpDef;
use ::proptest::prelude::*;

use crate::{
builder::test::simple_dfg_hugr,
extension::{
op_def::LowerFunc, CustomValidator, ExtensionId, ExtensionSet, OpDef, SignatureFunc,
},
types::PolyFuncType,
};

impl Arbitrary for SignatureFunc {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
// TODO there is also SignatureFunc::CustomFunc, but for now
// this is not serialised. When it is, we should generate
// examples here .
any::<PolyFuncType>()
.prop_map(|x| SignatureFunc::TypeScheme(CustomValidator::from_polyfunc(x)))
.boxed()
}
}

impl Arbitrary for LowerFunc {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
// TODO There is also LowerFunc::CustomFunc, but for now this is
// not serialised. When it is, we should generate examples here.
any::<ExtensionSet>()
.prop_map(|extensions| LowerFunc::FixedHugr {
extensions,
hugr: simple_dfg_hugr(),
})
.boxed()
}
}

impl Arbitrary for SimpleOpDef {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
use crate::proptest::{any_serde_yaml_value, any_smolstr, any_string};
use proptest::collection::{hash_map, vec};
let signature_func: BoxedStrategy<SignatureFunc> = any::<SignatureFunc>();
let lower_funcs: BoxedStrategy<LowerFunc> = any::<LowerFunc>();
let misc = hash_map(any_string(), any_serde_yaml_value(), 0..3);
(
any::<ExtensionId>(),
any_smolstr(),
any_string(),
misc,
signature_func,
vec(lower_funcs, 0..2),
)
.prop_map(
|(extension, name, description, misc, signature_func, lower_funcs)| {
Self(OpDef {
extension,
name,
description,
misc,
signature_func,
lower_funcs,
// TODO ``constant_folder` is not serialised, we should
// generate examples once it is.
constant_folder: None,
})
},
)
.boxed()
}
}
}
}
15 changes: 11 additions & 4 deletions hugr/src/hugr/serialize/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::builder::{
};
use crate::extension::prelude::{BOOL_T, USIZE_T};
use crate::extension::simple_op::MakeRegisteredOp;
use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY};
use crate::extension::{test::SimpleOpDef, EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::{dataflow::IOTrait, Input, Module, Noop, Output, DFG};
Expand Down Expand Up @@ -34,6 +34,7 @@ struct SerTestingV1 {
poly_func_type: Option<crate::types::PolyFuncType>,
value: Option<crate::ops::Value>,
optype: Option<NodeSer>,
op_def: Option<SimpleOpDef>,
}

type TestingModel = SerTestingV1;
Expand Down Expand Up @@ -88,6 +89,7 @@ impl_sertesting_from!(crate::types::SumType, sum_type);
impl_sertesting_from!(crate::types::PolyFuncType, poly_func_type);
impl_sertesting_from!(crate::ops::Value, value);
impl_sertesting_from!(NodeSer, optype);
impl_sertesting_from!(SimpleOpDef, op_def);

#[test]
fn empty_hugr_serialize() {
Expand Down Expand Up @@ -377,8 +379,8 @@ fn serialize_types_roundtrip() {

#[cfg(feature = "proptest")]
mod proptest {
use super::super::NodeSer;
use super::check_testing_roundtrip;
use super::{NodeSer, SimpleOpDef};
use crate::extension::ExtensionSet;
use crate::ops::{OpType, Value};
use crate::types::{PolyFuncType, Type};
Expand Down Expand Up @@ -419,8 +421,13 @@ mod proptest {
}

#[test]
fn prop_roundtrip_optype(ns: NodeSer) {
check_testing_roundtrip(ns)
fn prop_roundtrip_optype(op: NodeSer ) {
check_testing_roundtrip(op)
}

#[test]
fn prop_roundtrip_opdef(opdef: SimpleOpDef) {
check_testing_roundtrip(opdef)
}
}
}
Loading

0 comments on commit 070365b

Please sign in to comment.