diff --git a/hugr-core/src/hugr/views/snapshots/hugr_core__hugr__views__tests__dot_cfg.snap b/hugr-core/src/hugr/views/snapshots/hugr_core__hugr__views__tests__dot_cfg.snap
index c07f68144..e81a393d8 100644
--- a/hugr-core/src/hugr/views/snapshots/hugr_core__hugr__views__tests__dot_cfg.snap
+++ b/hugr-core/src/hugr/views/snapshots/hugr_core__hugr__views__tests__dot_cfg.snap
@@ -18,7 +18,7 @@ digraph {
7 [shape=plain label=<
>]
7:out0 -> 8:in1 [style=""]
8 [shape=plain label=<>]
-9 [shape=plain label=<(9) const:sum:{tag:0, vals:[]} |
0: [] |
>]
+9 [shape=plain label=<>]
9:out0 -> 10:in0 [style=""]
10 [shape=plain label=<0: [] |
(10) LoadConstant |
0: [] |
>]
10:out0 -> 8:in0 [style=""]
@@ -44,3 +44,4 @@ hier8 [shape=plain label="8"]
hier9 [shape=plain label="9"]
hier10 [shape=plain label="10"]
}
+
diff --git a/hugr-core/src/hugr/views/snapshots/hugr_core__hugr__views__tests__mmd_cfg.snap b/hugr-core/src/hugr/views/snapshots/hugr_core__hugr__views__tests__mmd_cfg.snap
index e70c19256..99b546a22 100644
--- a/hugr-core/src/hugr/views/snapshots/hugr_core__hugr__views__tests__mmd_cfg.snap
+++ b/hugr-core/src/hugr/views/snapshots/hugr_core__hugr__views__tests__mmd_cfg.snap
@@ -21,10 +21,11 @@ graph LR
7["(7) Input"]
7--"0:1
usize"-->8
8["(8) Output"]
- 9["(9) const:sum:{tag:0, vals:[]}"]
+ 9["(9) const:seq:{}"]
9--"0:0
[]"-->10
10["(10) LoadConstant"]
10--"0:0
[]"-->8
end
6-."0:0".->1
end
+
diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs
index 1cb13815f..696d3c89e 100644
--- a/hugr-core/src/ops/constant.rs
+++ b/hugr-core/src/ops/constant.rs
@@ -101,6 +101,76 @@ impl AsRef for Const {
}
}
+#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
+struct SerialSum {
+ #[serde(default)]
+ tag: usize,
+ #[serde(rename = "vs")]
+ values: Vec,
+ #[serde(default, rename = "typ")]
+ sum_type: Option,
+}
+
+#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
+#[serde(try_from = "SerialSum")]
+#[serde(into = "SerialSum")]
+/// A Sum variant, with a tag indicating the index of the variant and its
+/// value.
+pub struct Sum {
+ /// The tag index of the variant.
+ pub tag: usize,
+ /// The value of the variant.
+ ///
+ /// Sum variants are always a row of values, hence the Vec.
+ pub values: Vec,
+ /// The full type of the Sum, including the other variants.
+ pub sum_type: SumType,
+}
+
+impl Sum {
+ /// If value is a sum with a single row variant, return the row.
+ pub fn as_tuple(&self) -> Option<&[Value]> {
+ self.sum_type.as_tuple().map(|_| self.values.as_ref())
+ }
+}
+
+impl TryFrom for Sum {
+ type Error = &'static str;
+
+ fn try_from(value: SerialSum) -> Result {
+ let SerialSum {
+ tag,
+ values,
+ sum_type,
+ } = value;
+
+ let sum_type = if let Some(sum_type) = sum_type {
+ sum_type
+ } else {
+ if tag != 0 {
+ return Err("Sum type must be provided if tag is not 0");
+ }
+ SumType::new_tuple(values.iter().map(Value::get_type).collect_vec())
+ };
+
+ Ok(Self {
+ tag,
+ values,
+ sum_type,
+ })
+ }
+}
+
+impl From for SerialSum {
+ fn from(value: Sum) -> Self {
+ Self {
+ tag: value.tag,
+ values: value.values,
+ sum_type: Some(value.sum_type),
+ }
+ }
+}
+
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "v")]
/// A value that can be stored as a static constant. Representing core types and
@@ -118,25 +188,10 @@ pub enum Value {
/// A Hugr defining the function.
hugr: Box,
},
- /// A tuple
- Tuple {
- /// Constant values in the tuple.
- vs: Vec,
- },
/// A Sum variant, with a tag indicating the index of the variant and its
/// value.
- Sum {
- /// The tag index of the variant.
- tag: usize,
- /// The value of the variant.
- ///
- /// Sum variants are always a row of values, hence the Vec.
- #[serde(rename = "vs")]
- values: Vec,
- /// The full type of the Sum, including the other variants.
- #[serde(rename = "typ")]
- sum_type: SumType,
- },
+ #[serde(alias = "Tuple")]
+ Sum(Sum),
}
/// An opaque newtype around a [`Box`](CustomConst).
@@ -286,8 +341,7 @@ impl Value {
pub fn get_type(&self) -> Type {
match self {
Self::Extension { e } => e.get_type(),
- Self::Tuple { vs } => Type::new_tuple(vs.iter().map(Self::get_type).collect_vec()),
- Self::Sum { sum_type, .. } => sum_type.clone().into(),
+ Self::Sum(Sum { sum_type, .. }) => sum_type.clone().into(),
Self::Function { hugr } => {
let func_type = mono_fn_type(hugr).unwrap_or_else(|e| panic!("{}", e));
Type::new_function(func_type)
@@ -305,18 +359,19 @@ impl Value {
) -> Result {
let values: Vec = items.into_iter().collect();
typ.check_type(tag, &values)?;
- Ok(Self::Sum {
+ Ok(Self::Sum(Sum {
tag,
values,
sum_type: typ,
- })
+ }))
}
/// Returns a tuple constant of constant values.
pub fn tuple(items: impl IntoIterator- ) -> Self {
- Self::Tuple {
- vs: items.into_iter().collect(),
- }
+ let vs = items.into_iter().collect_vec();
+ let tys = vs.iter().map(Self::get_type).collect_vec();
+
+ Self::sum(0, vs, SumType::new_tuple(tys)).expect("Tuple type is valid")
}
/// Returns a constant function defined by a Hugr.
@@ -334,7 +389,11 @@ impl Value {
/// Returns a constant unit type (empty Tuple).
pub const fn unit() -> Self {
- Self::Tuple { vs: vec![] }
+ Self::Sum(Sum {
+ tag: 0,
+ values: vec![],
+ sum_type: SumType::Unit { size: 1 },
+ })
}
/// Returns a constant Sum over units. Used as branching values.
@@ -393,12 +452,17 @@ impl Value {
};
format!("const:function:[{}]", t)
}
- Self::Tuple { vs: vals } => {
- let names: Vec<_> = vals.iter().map(Value::name).collect();
- format!("const:seq:{{{}}}", names.iter().join(", "))
- }
- Self::Sum { tag, values, .. } => {
- format!("const:sum:{{tag:{tag}, vals:{values:?}}}")
+ Self::Sum(Sum {
+ tag,
+ values,
+ sum_type,
+ }) => {
+ if sum_type.as_tuple().is_some() {
+ let names: Vec<_> = values.iter().map(Value::name).collect();
+ format!("const:seq:{{{}}}", names.iter().join(", "))
+ } else {
+ format!("const:sum:{{tag:{tag}, vals:{values:?}}}")
+ }
}
}
.into()
@@ -409,8 +473,7 @@ impl Value {
match self {
Self::Extension { e } => e.extension_reqs().clone(),
Self::Function { .. } => ExtensionSet::new(), // no extensions required to load Hugr (only to run)
- Self::Tuple { vs } => ExtensionSet::union_over(vs.iter().map(Value::extension_reqs)),
- Self::Sum { values, .. } => {
+ Self::Sum(Sum { values, .. }) => {
ExtensionSet::union_over(values.iter().map(|x| x.extension_reqs()))
}
}
@@ -424,22 +487,25 @@ impl Value {
mono_fn_type(hugr)?;
Ok(())
}
- Self::Tuple { vs } => {
- for v in vs {
- v.validate()?;
- }
- Ok(())
- }
- Self::Sum {
+ Self::Sum(Sum {
tag,
values,
sum_type,
- } => {
+ }) => {
sum_type.check_type(*tag, values)?;
Ok(())
}
}
}
+
+ /// If value is a sum with a single row variant, return the row.
+ pub fn as_tuple(&self) -> Option<&[Value]> {
+ if let Self::Sum(sum) = self {
+ sum.as_tuple()
+ } else {
+ None
+ }
+ }
}
impl From for Value
@@ -662,7 +728,7 @@ mod test {
}
mod proptest {
- use super::super::OpaqueValue;
+ use super::super::{OpaqueValue, Sum};
use crate::{
ops::{constant::CustomSerialized, Value},
std_extensions::arithmetic::int_types::ConstInt,
@@ -715,7 +781,7 @@ mod test {
3, // Each collection is up to 3 elements long
|element| {
prop_oneof![
- vec(element.clone(), 0..3).prop_map(|vs| Self::Tuple { vs }),
+ vec(element.clone(), 0..3).prop_map(Self::tuple),
(
any::(),
vec(element.clone(), 0..3),
@@ -723,11 +789,11 @@ mod test {
)
.prop_map(
|(tag, values, sum_type)| {
- Self::Sum {
+ Self::Sum(Sum {
tag,
values,
sum_type,
- }
+ })
}
),
]
@@ -737,4 +803,75 @@ mod test {
}
}
}
+
+ #[test]
+ fn test_tuple_deserialize() {
+ let json = r#"
+ {
+ "v": "Tuple",
+ "vs": [
+ {
+ "v": "Sum",
+ "tag": 0,
+ "typ": {
+ "t": "Sum",
+ "s": "Unit",
+ "size": 1
+ },
+ "vs": []
+ },
+ {
+ "v": "Sum",
+ "tag": 1,
+ "typ": {
+ "t": "Sum",
+ "s": "General",
+ "rows": [
+ [
+ {
+ "t": "Sum",
+ "s": "Unit",
+ "size": 1
+ }
+ ],
+ [
+ {
+ "t": "Sum",
+ "s": "Unit",
+ "size": 2
+ }
+ ]
+ ]
+ },
+ "vs": [
+ {
+ "v": "Sum",
+ "tag": 1,
+ "typ": {
+ "t": "Sum",
+ "s": "Unit",
+ "size": 2
+ },
+ "vs": []
+ }
+ ]
+ }
+ ]
+}
+ "#;
+
+ let v: Value = serde_json::from_str(json).unwrap();
+ assert_eq!(
+ v,
+ Value::tuple([
+ Value::unit(),
+ Value::sum(
+ 1,
+ [Value::true_val()],
+ SumType::new([vec![Type::UNIT], vec![Value::true_val().get_type()]]),
+ )
+ .unwrap()
+ ])
+ );
+ }
}
diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs
index cb8cd0706..3bb0eafbd 100644
--- a/hugr-core/src/types.rs
+++ b/hugr-core/src/types.rs
@@ -169,6 +169,11 @@ impl SumType {
Self::Unit { size }
}
+ /// New tuple (single row of variants)
+ pub fn new_tuple(types: impl Into) -> Self {
+ Self::new([types.into()])
+ }
+
/// Report the tag'th variant, if it exists.
pub fn get_variant(&self, tag: usize) -> Option<&TypeRow> {
match self {
@@ -185,6 +190,15 @@ impl SumType {
SumType::General { rows } => rows.len(),
}
}
+
+ /// Returns variant row if there is only one variant
+ pub fn as_tuple(&self) -> Option<&TypeRow> {
+ match self {
+ SumType::Unit { size } if *size == 1 => Some(Type::EMPTY_TYPEROW_REF),
+ SumType::General { rows } if rows.len() == 1 => Some(&rows[0]),
+ _ => None,
+ }
+ }
}
impl From for Type {
diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs
index 4f4f44f6b..77cc3748e 100644
--- a/hugr-passes/src/const_fold.rs
+++ b/hugr-passes/src/const_fold.rs
@@ -109,7 +109,7 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR
}
OpType::UnpackTuple { .. } => {
let c = &consts.first()?.1;
- let Value::Tuple { vs } = c else {
+ let Some(vs) = c.as_tuple() else {
panic!("This op always takes a Tuple input.");
};
out_row(vs.iter().cloned())