diff --git a/hugr-core/src/hugr/views/snapshots/hugr_core__hugr__views__tests__dot_cfg.snap b/hugr-core/src/hugr/views/snapshots/hugr_core__hugr__views__tests__dot_cfg.snap index c07f68144..e81a393d8 100644 --- a/hugr-core/src/hugr/views/snapshots/hugr_core__hugr__views__tests__dot_cfg.snap +++ b/hugr-core/src/hugr/views/snapshots/hugr_core__hugr__views__tests__dot_cfg.snap @@ -18,7 +18,7 @@ digraph { 7 [shape=plain label=<
(7) Input
0: usize
>] 7:out0 -> 8:in1 [style=""] 8 [shape=plain label=<
0: []1: usize
(8) Output
>] -9 [shape=plain label=<
(9) const:sum:{tag:0, vals:[]}
0: []
>] +9 [shape=plain label=<
(9) const:seq:{}
0: []
>] 9:out0 -> 10:in0 [style=""] 10 [shape=plain label=<
0: []
(10) LoadConstant
0: []
>] 10:out0 -> 8:in0 [style=""] @@ -44,3 +44,4 @@ hier8 [shape=plain label="8"] hier9 [shape=plain label="9"] hier10 [shape=plain label="10"] } + diff --git a/hugr-core/src/hugr/views/snapshots/hugr_core__hugr__views__tests__mmd_cfg.snap b/hugr-core/src/hugr/views/snapshots/hugr_core__hugr__views__tests__mmd_cfg.snap index e70c19256..99b546a22 100644 --- a/hugr-core/src/hugr/views/snapshots/hugr_core__hugr__views__tests__mmd_cfg.snap +++ b/hugr-core/src/hugr/views/snapshots/hugr_core__hugr__views__tests__mmd_cfg.snap @@ -21,10 +21,11 @@ graph LR 7["(7) Input"] 7--"0:1
usize"-->8 8["(8) Output"] - 9["(9) const:sum:{tag:0, vals:[]}"] + 9["(9) const:seq:{}"] 9--"0:0
[]"-->10 10["(10) LoadConstant"] 10--"0:0
[]"-->8 end 6-."0:0".->1 end + diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 1cb13815f..696d3c89e 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -101,6 +101,76 @@ impl AsRef for Const { } } +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +struct SerialSum { + #[serde(default)] + tag: usize, + #[serde(rename = "vs")] + values: Vec, + #[serde(default, rename = "typ")] + sum_type: Option, +} + +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +#[serde(try_from = "SerialSum")] +#[serde(into = "SerialSum")] +/// A Sum variant, with a tag indicating the index of the variant and its +/// value. +pub struct Sum { + /// The tag index of the variant. + pub tag: usize, + /// The value of the variant. + /// + /// Sum variants are always a row of values, hence the Vec. + pub values: Vec, + /// The full type of the Sum, including the other variants. + pub sum_type: SumType, +} + +impl Sum { + /// If value is a sum with a single row variant, return the row. + pub fn as_tuple(&self) -> Option<&[Value]> { + self.sum_type.as_tuple().map(|_| self.values.as_ref()) + } +} + +impl TryFrom for Sum { + type Error = &'static str; + + fn try_from(value: SerialSum) -> Result { + let SerialSum { + tag, + values, + sum_type, + } = value; + + let sum_type = if let Some(sum_type) = sum_type { + sum_type + } else { + if tag != 0 { + return Err("Sum type must be provided if tag is not 0"); + } + SumType::new_tuple(values.iter().map(Value::get_type).collect_vec()) + }; + + Ok(Self { + tag, + values, + sum_type, + }) + } +} + +impl From for SerialSum { + fn from(value: Sum) -> Self { + Self { + tag: value.tag, + values: value.values, + sum_type: Some(value.sum_type), + } + } +} + #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[serde(tag = "v")] /// A value that can be stored as a static constant. Representing core types and @@ -118,25 +188,10 @@ pub enum Value { /// A Hugr defining the function. hugr: Box, }, - /// A tuple - Tuple { - /// Constant values in the tuple. - vs: Vec, - }, /// A Sum variant, with a tag indicating the index of the variant and its /// value. - Sum { - /// The tag index of the variant. - tag: usize, - /// The value of the variant. - /// - /// Sum variants are always a row of values, hence the Vec. - #[serde(rename = "vs")] - values: Vec, - /// The full type of the Sum, including the other variants. - #[serde(rename = "typ")] - sum_type: SumType, - }, + #[serde(alias = "Tuple")] + Sum(Sum), } /// An opaque newtype around a [`Box`](CustomConst). @@ -286,8 +341,7 @@ impl Value { pub fn get_type(&self) -> Type { match self { Self::Extension { e } => e.get_type(), - Self::Tuple { vs } => Type::new_tuple(vs.iter().map(Self::get_type).collect_vec()), - Self::Sum { sum_type, .. } => sum_type.clone().into(), + Self::Sum(Sum { sum_type, .. }) => sum_type.clone().into(), Self::Function { hugr } => { let func_type = mono_fn_type(hugr).unwrap_or_else(|e| panic!("{}", e)); Type::new_function(func_type) @@ -305,18 +359,19 @@ impl Value { ) -> Result { let values: Vec = items.into_iter().collect(); typ.check_type(tag, &values)?; - Ok(Self::Sum { + Ok(Self::Sum(Sum { tag, values, sum_type: typ, - }) + })) } /// Returns a tuple constant of constant values. pub fn tuple(items: impl IntoIterator) -> Self { - Self::Tuple { - vs: items.into_iter().collect(), - } + let vs = items.into_iter().collect_vec(); + let tys = vs.iter().map(Self::get_type).collect_vec(); + + Self::sum(0, vs, SumType::new_tuple(tys)).expect("Tuple type is valid") } /// Returns a constant function defined by a Hugr. @@ -334,7 +389,11 @@ impl Value { /// Returns a constant unit type (empty Tuple). pub const fn unit() -> Self { - Self::Tuple { vs: vec![] } + Self::Sum(Sum { + tag: 0, + values: vec![], + sum_type: SumType::Unit { size: 1 }, + }) } /// Returns a constant Sum over units. Used as branching values. @@ -393,12 +452,17 @@ impl Value { }; format!("const:function:[{}]", t) } - Self::Tuple { vs: vals } => { - let names: Vec<_> = vals.iter().map(Value::name).collect(); - format!("const:seq:{{{}}}", names.iter().join(", ")) - } - Self::Sum { tag, values, .. } => { - format!("const:sum:{{tag:{tag}, vals:{values:?}}}") + Self::Sum(Sum { + tag, + values, + sum_type, + }) => { + if sum_type.as_tuple().is_some() { + let names: Vec<_> = values.iter().map(Value::name).collect(); + format!("const:seq:{{{}}}", names.iter().join(", ")) + } else { + format!("const:sum:{{tag:{tag}, vals:{values:?}}}") + } } } .into() @@ -409,8 +473,7 @@ impl Value { match self { Self::Extension { e } => e.extension_reqs().clone(), Self::Function { .. } => ExtensionSet::new(), // no extensions required to load Hugr (only to run) - Self::Tuple { vs } => ExtensionSet::union_over(vs.iter().map(Value::extension_reqs)), - Self::Sum { values, .. } => { + Self::Sum(Sum { values, .. }) => { ExtensionSet::union_over(values.iter().map(|x| x.extension_reqs())) } } @@ -424,22 +487,25 @@ impl Value { mono_fn_type(hugr)?; Ok(()) } - Self::Tuple { vs } => { - for v in vs { - v.validate()?; - } - Ok(()) - } - Self::Sum { + Self::Sum(Sum { tag, values, sum_type, - } => { + }) => { sum_type.check_type(*tag, values)?; Ok(()) } } } + + /// If value is a sum with a single row variant, return the row. + pub fn as_tuple(&self) -> Option<&[Value]> { + if let Self::Sum(sum) = self { + sum.as_tuple() + } else { + None + } + } } impl From for Value @@ -662,7 +728,7 @@ mod test { } mod proptest { - use super::super::OpaqueValue; + use super::super::{OpaqueValue, Sum}; use crate::{ ops::{constant::CustomSerialized, Value}, std_extensions::arithmetic::int_types::ConstInt, @@ -715,7 +781,7 @@ mod test { 3, // Each collection is up to 3 elements long |element| { prop_oneof![ - vec(element.clone(), 0..3).prop_map(|vs| Self::Tuple { vs }), + vec(element.clone(), 0..3).prop_map(Self::tuple), ( any::(), vec(element.clone(), 0..3), @@ -723,11 +789,11 @@ mod test { ) .prop_map( |(tag, values, sum_type)| { - Self::Sum { + Self::Sum(Sum { tag, values, sum_type, - } + }) } ), ] @@ -737,4 +803,75 @@ mod test { } } } + + #[test] + fn test_tuple_deserialize() { + let json = r#" + { + "v": "Tuple", + "vs": [ + { + "v": "Sum", + "tag": 0, + "typ": { + "t": "Sum", + "s": "Unit", + "size": 1 + }, + "vs": [] + }, + { + "v": "Sum", + "tag": 1, + "typ": { + "t": "Sum", + "s": "General", + "rows": [ + [ + { + "t": "Sum", + "s": "Unit", + "size": 1 + } + ], + [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ] + ] + }, + "vs": [ + { + "v": "Sum", + "tag": 1, + "typ": { + "t": "Sum", + "s": "Unit", + "size": 2 + }, + "vs": [] + } + ] + } + ] +} + "#; + + let v: Value = serde_json::from_str(json).unwrap(); + assert_eq!( + v, + Value::tuple([ + Value::unit(), + Value::sum( + 1, + [Value::true_val()], + SumType::new([vec![Type::UNIT], vec![Value::true_val().get_type()]]), + ) + .unwrap() + ]) + ); + } } diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index cb8cd0706..3bb0eafbd 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -169,6 +169,11 @@ impl SumType { Self::Unit { size } } + /// New tuple (single row of variants) + pub fn new_tuple(types: impl Into) -> Self { + Self::new([types.into()]) + } + /// Report the tag'th variant, if it exists. pub fn get_variant(&self, tag: usize) -> Option<&TypeRow> { match self { @@ -185,6 +190,15 @@ impl SumType { SumType::General { rows } => rows.len(), } } + + /// Returns variant row if there is only one variant + pub fn as_tuple(&self) -> Option<&TypeRow> { + match self { + SumType::Unit { size } if *size == 1 => Some(Type::EMPTY_TYPEROW_REF), + SumType::General { rows } if rows.len() == 1 => Some(&rows[0]), + _ => None, + } + } } impl From for Type { diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 4f4f44f6b..77cc3748e 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -109,7 +109,7 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR } OpType::UnpackTuple { .. } => { let c = &consts.first()?.1; - let Value::Tuple { vs } = c else { + let Some(vs) = c.as_tuple() else { panic!("This op always takes a Tuple input."); }; out_row(vs.iter().cloned())