diff --git a/hugr/src/hugr/serialize.rs b/hugr/src/hugr/serialize.rs index a753f0998..1cf7e317e 100644 --- a/hugr/src/hugr/serialize.rs +++ b/hugr/src/hugr/serialize.rs @@ -70,10 +70,12 @@ pub enum HUGRSerializationError { #[error("Failed to build edge when deserializing: {0:?}.")] LinkError(#[from] LinkError), /// Edges without port offsets cannot be present in operations without non-dataflow ports. - #[error("Cannot connect an edge without port offset to node {node:?} with operation type {op_type:?}.")] + #[error("Cannot connect an {dir:?} edge without port offset to node {node:?} with operation type {op_type:?}.")] MissingPortOffset { /// The node that has the port without offset. node: Node, + /// The direction of the port without an offset + dir: Direction, /// The operation type of the node. op_type: OpType, }, @@ -232,6 +234,7 @@ impl TryFrom for Hugr { .other_port(dir) .ok_or(HUGRSerializationError::MissingPortOffset { node, + dir, op_type: op_type.clone(), })? .index() @@ -329,10 +332,20 @@ pub mod test { } /// Serialize and deserialize a HUGR, and check that the result is the same as the original. + /// Checks the serialized json against the in-tree schema. /// /// Returns the deserialized HUGR. - pub fn check_hugr_roundtrip(hugr: &Hugr) -> Hugr { - let new_hugr: Hugr = ser_roundtrip_validate(hugr, Some(&SCHEMA)); + pub fn check_hugr_schema_roundtrip(hugr: &Hugr) -> Hugr { + check_hugr_roundtrip(hugr, true) + } + + /// Serialize and deserialize a HUGR, and check that the result is the same as the original. + /// + /// If `check_schema` is true, checks the serialized json against the in-tree schema. + /// + /// Returns the deserialized HUGR. + pub fn check_hugr_roundtrip(hugr: &Hugr, check_schema: bool) -> Hugr { + let new_hugr: Hugr = ser_roundtrip_validate(hugr, check_schema.then_some(&SCHEMA)); // Original HUGR, with canonicalized node indices // @@ -418,7 +431,7 @@ pub mod test { metadata: Default::default(), }; - check_hugr_roundtrip(&hugr); + check_hugr_schema_roundtrip(&hugr); } #[test] @@ -452,7 +465,7 @@ pub mod test { module_builder.finish_prelude_hugr().unwrap() }; - check_hugr_roundtrip(&hugr); + check_hugr_schema_roundtrip(&hugr); } #[test] @@ -468,7 +481,7 @@ pub mod test { } let hugr = dfg.finish_hugr_with_outputs(params, &EMPTY_REG)?; - check_hugr_roundtrip(&hugr); + check_hugr_schema_roundtrip(&hugr); Ok(()) } @@ -491,7 +504,7 @@ pub mod test { let hugr = dfg.finish_hugr_with_outputs([wire], &PRELUDE_REGISTRY)?; - check_hugr_roundtrip(&hugr); + check_hugr_schema_roundtrip(&hugr); Ok(()) } @@ -502,7 +515,7 @@ pub mod test { let op = bldr.add_dataflow_op(Noop { ty: fn_ty }, bldr.input_wires())?; let h = bldr.finish_prelude_hugr_with_outputs(op.outputs())?; - check_hugr_roundtrip(&h); + check_hugr_schema_roundtrip(&h); Ok(()) } @@ -520,7 +533,7 @@ pub mod test { hugr.remove_node(old_in); hugr.update_validate(&PRELUDE_REGISTRY)?; - let new_hugr: Hugr = check_hugr_roundtrip(&hugr); + let new_hugr: Hugr = check_hugr_schema_roundtrip(&hugr); new_hugr.validate(&EMPTY_REG).unwrap_err(); new_hugr.validate(&PRELUDE_REGISTRY)?; Ok(()) diff --git a/hugr/src/hugr/validate.rs b/hugr/src/hugr/validate.rs index 86e5eb24a..4343d2765 100644 --- a/hugr/src/hugr/validate.rs +++ b/hugr/src/hugr/validate.rs @@ -98,6 +98,15 @@ impl<'a, 'b> ValidationContext<'a, 'b> { // Hierarchy and children. No type variables declared outside the root. self.validate_subtree(self.hugr.root(), &[])?; + // In tests we take the opportunity to verify that the hugr + // serialization round-trips. + // + // TODO: We should also verify that the serialized hugr matches the + // in-tree schema. For now, our serialized hugr does not match the + // schema. When this is fixed we should pass true below. + #[cfg(test)] + crate::hugr::serialize::test::check_hugr_roundtrip(self.hugr, false); + Ok(()) } diff --git a/hugr/src/ops.rs b/hugr/src/ops.rs index 07a59d116..f20d67c79 100644 --- a/hugr/src/ops.rs +++ b/hugr/src/ops.rs @@ -189,13 +189,13 @@ impl OpType { /// /// Returns None if there is no such port, or if the operation defines multiple non-dataflow ports. pub fn other_port(&self, dir: Direction) -> Option { + let df_count = self.value_port_count(dir); let non_df_count = self.non_df_port_count(dir); - if self.other_port_kind(dir).is_some() && non_df_count == 1 { - // if there is a static input it comes before the non_df_ports - let static_input = - (dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize; - - Some(Port::new(dir, self.value_port_count(dir) + static_input)) + // if there is a static input it comes before the non_df_ports + let static_input = + (dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize; + if self.other_port_kind(dir).is_some() && non_df_count >= 1 { + Some(Port::new(dir, df_count + static_input)) } else { None } diff --git a/hugr/src/ops/constant.rs b/hugr/src/ops/constant.rs index 5643ccab4..0676e7017 100644 --- a/hugr/src/ops/constant.rs +++ b/hugr/src/ops/constant.rs @@ -305,7 +305,7 @@ mod test { use super::*; - #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] + #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] /// A custom constant value used in testing pub(crate) struct CustomTestValue(pub CustomType); @@ -322,6 +322,10 @@ mod test { fn get_type(&self) -> Type { self.0.clone().into() } + + fn equal_consts(&self, other: &dyn CustomConst) -> bool { + crate::ops::constant::downcast_equal_consts(self, other) + } } /// A [`CustomSerialized`] encoding a [`FLOAT64_TYPE`] float constant used in testing.