Skip to content

Commit

Permalink
fix!: serialisation schema (#968)
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q authored Apr 30, 2024
1 parent 76dcc80 commit d913f40
Show file tree
Hide file tree
Showing 16 changed files with 1,415 additions and 234 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci-rs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ env:
CI: true # insta snapshots behave differently on ci
SCCACHE_GHA_ENABLED: "true"
RUSTC_WRAPPER: "sccache"
HUGR_TEST_SCHEMA: "1"

jobs:
# Check if changes were made to the relevant files.
Expand Down
11 changes: 10 additions & 1 deletion hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
PolyFuncType,
Type,
TypeRow,
SumType,
TypeBound,
)

NodeID = int
Expand Down Expand Up @@ -126,7 +128,7 @@ class SumValue(BaseModel):

c: Literal["Sum"] = Field("Sum", title="ValueTag")
tag: int
typ: Type
typ: SumType
vs: list["Value"]

class Config:
Expand Down Expand Up @@ -475,6 +477,12 @@ class Lift(DataflowOp):
new_extension: ExtensionId


class AliasDecl(BaseOp):
op: Literal["AliasDecl"] = "AliasDecl"
name: str
bound: TypeBound


class OpType(RootModel):
"""A constant operation."""

Expand All @@ -501,6 +509,7 @@ class OpType(RootModel):
| Tag
| Lift
| DFG
| AliasDecl
) = Field(discriminator="op")


Expand Down
23 changes: 23 additions & 0 deletions hugr-py/src/hugr/serialization/testing_hugr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Literal, Optional
from pydantic import BaseModel
from .tys import Type, SumType, PolyFuncType
from .ops import Value


class TestingHugr(BaseModel):
"""A serializable representation of a Hugr Type, SumType, PolyFuncType, or
Value. Intended for testing only."""

version: Literal["v1"] = "v1"
typ: Optional[Type] = None
sum_type: Optional[SumType] = None
poly_func_type: Optional[PolyFuncType] = None
value: Optional[Value] = None

@classmethod
def get_version(cls) -> str:
"""Return the version of the schema."""
return cls().version

class Config:
title = "HugrTesting"
41 changes: 33 additions & 8 deletions hugr-py/src/hugr/serialization/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,20 @@ class TupleParam(BaseModel):
params: list["TypeParam"]


class ExtensionsParam(BaseModel):
tp: Literal["Extensions"] = "Extensions"


class TypeParam(RootModel):
"""A type parameter."""

root: Annotated[
TypeTypeParam | BoundedNatParam | OpaqueParam | ListParam | TupleParam,
TypeTypeParam
| BoundedNatParam
| OpaqueParam
| ListParam
| TupleParam
| ExtensionsParam,
WrapValidator(_json_custom_error_validator),
] = Field(discriminator="tp")

Expand Down Expand Up @@ -145,9 +154,7 @@ class Array(MultiContainer):


class UnitSum(BaseModel):
"""Simple predicate where all variants are empty tuples."""

t: Literal["Sum"] = "Sum"
"""Simple sum type where all variants are empty tuples."""

s: Literal["Unit"] = "Unit"
size: int
Expand All @@ -156,8 +163,6 @@ class UnitSum(BaseModel):
class GeneralSum(BaseModel):
"""General sum type that explicitly stores the types of the variants."""

t: Literal["Sum"] = "Sum"

s: Literal["General"] = "General"
rows: list["TypeRow"]

Expand All @@ -166,6 +171,11 @@ class SumType(RootModel):
root: Union[UnitSum, GeneralSum] = Field(discriminator="s")


class TaggedSumType(BaseModel):
t: Literal["Sum"] = "Sum"
st: SumType


# ----------------------------------------------
# --------------- ClassicType ------------------
# ----------------------------------------------
Expand Down Expand Up @@ -254,7 +264,7 @@ def join(*bs: "TypeBound") -> "TypeBound":


class Opaque(BaseModel):
"""An opaque operation that can be downcasted by the extensions that define it."""
"""An opaque Type that can be downcasted by the extensions that define it."""

t: Literal["Opaque"] = "Opaque"
extension: ExtensionId
Expand All @@ -263,6 +273,14 @@ class Opaque(BaseModel):
bound: TypeBound


class Alias(BaseModel):
"""An Alias Type"""

t: Literal["Alias"] = "Alias"
bound: TypeBound
name: str


# ----------------------------------------------
# --------------- LinearType -------------------
# ----------------------------------------------
Expand All @@ -278,7 +296,14 @@ class Type(RootModel):
"""A HUGR type."""

root: Annotated[
Qubit | Variable | USize | FunctionType | Array | SumType | Opaque,
Qubit
| Variable
| USize
| FunctionType
| Array
| TaggedSumType
| Opaque
| Alias,
WrapValidator(_json_custom_error_validator),
] = Field(discriminator="t")

Expand Down
47 changes: 42 additions & 5 deletions hugr/src/hugr/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,29 @@ use super::{HugrMut, HugrView};
/// recent version of the format. We keep the `Deserialize` implementations for
/// older versions to allow for backwards compatibility.
///
/// The Generic `SerHugr` is always instantiated to the most recent version of
/// the format outside this module.
///
/// Make sure to order the variants from newest to oldest, as the deserializer
/// will try to deserialize them in order.
#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "version", rename_all = "lowercase")]
enum Versioned {
enum Versioned<SerHugr> {
/// Version 0 of the HUGR serialization format.
V0,
/// Version 1 of the HUGR serialization format.
V1(SerHugrV1),
V1(SerHugr),

#[serde(other)]
Unsupported,
}

impl<T> Versioned<T> {
pub fn new(t: T) -> Self {
Self::V1(t)
}
}

#[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
struct NodeSer {
parent: Node,
Expand All @@ -62,6 +71,34 @@ struct SerHugrV1 {
encoder: Option<String>,
}

/// Version 1 of the Testing HUGR serialisation format, see `testing_hugr.py`.
#[cfg(test)]
#[derive(Serialize, Deserialize, PartialEq, Debug, Default)]
struct SerTestingV1 {
typ: Option<crate::types::Type>,
sum_type: Option<crate::types::SumType>,
poly_func_type: Option<crate::types::PolyFuncType>,
value: Option<crate::ops::Value>,
}

macro_rules! impl_sertesting_from {
($typ:ty, $field:ident) => {
#[cfg(test)]
impl From<$typ> for SerTestingV1 {
fn from(v: $typ) -> Self {
let mut r: Self = Default::default();
r.$field = Some(v);
r
}
}
};
}

impl_sertesting_from!(crate::types::Type, typ);
impl_sertesting_from!(crate::types::SumType, sum_type);
impl_sertesting_from!(crate::types::PolyFuncType, poly_func_type);
impl_sertesting_from!(crate::ops::Value, value);

/// Errors that can occur while serializing a HUGR.
#[derive(Debug, Clone, PartialEq, Error)]
#[non_exhaustive]
Expand Down Expand Up @@ -99,7 +136,7 @@ impl Serialize for Hugr {
S: serde::Serializer,
{
let shg: SerHugrV1 = self.try_into().map_err(serde::ser::Error::custom)?;
let versioned = Versioned::V1(shg);
let versioned = Versioned::new(shg);
versioned.serialize(serializer)
}
}
Expand All @@ -109,7 +146,7 @@ impl<'de> Deserialize<'de> for Hugr {
where
D: Deserializer<'de>,
{
let shg = Versioned::deserialize(deserializer)?;
let shg: Versioned<SerHugrV1> = Versioned::deserialize(deserializer)?;
match shg {
Versioned::V0 => Err(serde::de::Error::custom(
"Version 0 HUGR serialization format is not supported.",
Expand Down
83 changes: 81 additions & 2 deletions hugr/src/hugr/serialize/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@ use crate::builder::{
test::closed_dfg_root_hugr, Container, DFGBuilder, Dataflow, DataflowHugr,
DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::{BOOL_T, USIZE_T};
use crate::extension::prelude::{BOOL_T, QB_T, USIZE_T};
use crate::extension::simple_op::MakeRegisteredOp;
use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::NodeType;
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::Value;
use crate::ops::{dataflow::IOTrait, Input, Module, Noop, Output, DFG};
use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::std_extensions::arithmetic::int_types::{int_custom_type, ConstInt, INT_TYPES};
use crate::std_extensions::logic::NotOp;
use crate::types::{FunctionType, Type};
use crate::types::type_param::{TypeArg, TypeParam};
use crate::types::{FunctionType, PolyFuncType, SumType, Type, TypeBound};
use crate::{type_row, OutgoingPort};
use itertools::Itertools;
use jsonschema::{Draft, JSONSchema};
Expand All @@ -22,10 +25,13 @@ use portgraph::LinkView;
use portgraph::{
multiportgraph::MultiPortGraph, Hierarchy, LinkMut, PortMut, PortView, UnmanagedDenseMap,
};
use rstest::rstest;

const NAT: Type = crate::extension::prelude::USIZE_T;
const QB: Type = crate::extension::prelude::QB_T;

type TestingModel = SerTestingV1;

lazy_static! {
static ref SCHEMA: JSONSchema = {
let schema_val: serde_json::Value = serde_json::from_str(include_str!(
Expand All @@ -37,6 +43,16 @@ lazy_static! {
.compile(&schema_val)
.expect("Schema is invalid.")
};
static ref TESTING_SCHEMA: JSONSchema = {
let schema_val: serde_json::Value = serde_json::from_str(include_str!(
"../../../../specification/schema/testing_hugr_schema_v1.json"
))
.unwrap();
JSONSchema::options()
.with_draft(Draft::Draft7)
.compile(&schema_val)
.expect("Schema is invalid.")
};
}

#[test]
Expand Down Expand Up @@ -124,6 +140,12 @@ pub fn check_hugr_roundtrip(hugr: &Hugr, check_schema: bool) -> Hugr {
new_hugr
}

fn check_testing_roundtrip(t: TestingModel) {
let before = Versioned::new(t);
let after = ser_roundtrip_validate(&before, Some(&TESTING_SCHEMA));
assert_eq!(before, after);
}

/// Generate an optype for a node with a matching amount of inputs and outputs.
fn gen_optype(g: &MultiPortGraph, node: portgraph::NodeIndex) -> OpType {
let inputs = g.num_inputs(node);
Expand Down Expand Up @@ -312,3 +334,60 @@ fn serialize_types_roundtrip() {
let t = Type::new_unit_sum(4);
assert_eq!(ser_roundtrip(&t), t);
}

#[rstest]
#[case(BOOL_T)]
#[case(USIZE_T)]
#[case(INT_TYPES[2].clone())]
#[case(Type::new_alias(crate::ops::AliasDecl::new("t", TypeBound::Any)))]
#[case(Type::new_var_use(2, TypeBound::Copyable))]
#[case(Type::new_tuple(type_row![BOOL_T,QB_T]))]
#[case(Type::new_sum([type_row![BOOL_T,QB_T], type_row![Type::new_unit_sum(4)]]))]
#[case(Type::new_function(FunctionType::new_endo(type_row![QB_T,BOOL_T,USIZE_T])))]
fn roundtrip_type(#[case] typ: Type) {
check_testing_roundtrip(typ.into())
}

#[rstest]
#[case(SumType::new_unary(2))]
#[case(SumType::new([type_row![USIZE_T, QB_T], type_row![]]))]
fn roundtrip_sumtype(#[case] sum_type: SumType) {
check_testing_roundtrip(sum_type.into())
}

#[rstest]
#[case(Value::unit())]
#[case(Value::true_val())]
#[case(Value::unit_sum(3,5).unwrap())]
#[case(Value::extension(ConstF64::new(-1.5)))]
#[case(Value::extension(ConstF64::new(0.0)))]
#[case(Value::extension(ConstF64::new(-0.0)))]
// These cases fail
// #[case(Value::extension(ConstF64::new(std::f64::NAN)))]
// #[case(Value::extension(ConstF64::new(std::f64::INFINITY)))]
// #[case(Value::extension(ConstF64::new(std::f64::NEG_INFINITY)))]
#[case(Value::extension(ConstF64::new(std::f64::MIN_POSITIVE)))]
#[case(Value::sum(1,[Value::extension(ConstInt::new_u(2,1).unwrap())], SumType::new([vec![], vec![INT_TYPES[2].clone()]])).unwrap())]
#[case(Value::tuple([Value::false_val(), Value::extension(ConstInt::new_s(2,1).unwrap())]))]
#[case(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap())]
fn roundtrip_value(#[case] value: Value) {
check_testing_roundtrip(value.into())
}

fn polyfunctype1() -> PolyFuncType {
let mut extension_set = ExtensionSet::new();
extension_set.insert_type_var(1);
let function_type = FunctionType::new_endo(type_row![]).with_extension_delta(extension_set);
PolyFuncType::new([TypeParam::max_nat(), TypeParam::Extensions], function_type)
}

#[rstest]
#[case(FunctionType::new_endo(type_row![]).into())]
#[case(polyfunctype1())]
#[case(PolyFuncType::new([TypeParam::Opaque { ty: int_custom_type(TypeArg::BoundedNat { n: 1 }) }], FunctionType::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))]
#[case(PolyFuncType::new([TypeBound::Eq.into()], FunctionType::new_endo(type_row![Type::new_var_use(0, TypeBound::Eq)])))]
#[case(PolyFuncType::new([TypeParam::List { param: Box::new(TypeBound::Any.into()) }], FunctionType::new_endo(type_row![])))]
#[case(PolyFuncType::new([TypeParam::Tuple { params: [TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], FunctionType::new_endo(type_row![])))]
fn roundtrip_polyfunctype(#[case] poly_func_type: PolyFuncType) {
check_testing_roundtrip(poly_func_type.into())
}
Loading

0 comments on commit d913f40

Please sign in to comment.