Skip to content

Commit

Permalink
refactor!: remove Value::Tuple
Browse files Browse the repository at this point in the history
uses new serialisation intermediary to keep serialized Tuple backwards compatibility

BREAKING: `Value::Sum` now holds a standalone struct `Sum`
  • Loading branch information
ss2165 committed Jul 2, 2024
1 parent 8f08b8c commit df3bddd
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ digraph {
7 [shape=plain label=<<table border="1"><tr><td align="text" border="0" colspan="1">(7) Input</td></tr><tr><td port="out0" align="text" colspan="1" cellpadding="1" >0: usize</td></tr></table>>]
7:out0 -> 8:in1 [style=""]
8 [shape=plain label=<<table border="1"><tr><td port="in0" align="text" colspan="1" cellpadding="1" >0: []</td><td port="in1" align="text" colspan="1" cellpadding="1" >1: usize</td></tr><tr><td align="text" border="0" colspan="2">(8) Output</td></tr></table>>]
9 [shape=plain label=<<table border="1"><tr><td align="text" border="0" colspan="1">(9) const:sum:{tag:0, vals:[]}</td></tr><tr><td port="out0" align="text" colspan="1" cellpadding="1" >0: []</td></tr></table>>]
9 [shape=plain label=<<table border="1"><tr><td align="text" border="0" colspan="1">(9) const:seq:{}</td></tr><tr><td port="out0" align="text" colspan="1" cellpadding="1" >0: []</td></tr></table>>]
9:out0 -> 10:in0 [style=""]
10 [shape=plain label=<<table border="1"><tr><td port="in0" align="text" colspan="1" cellpadding="1" >0: []</td></tr><tr><td align="text" border="0" colspan="1">(10) LoadConstant</td></tr><tr><td port="out0" align="text" colspan="1" cellpadding="1" >0: []</td></tr></table>>]
10:out0 -> 8:in0 [style=""]
Expand All @@ -44,3 +44,4 @@ hier8 [shape=plain label="8"]
hier9 [shape=plain label="9"]
hier10 [shape=plain label="10"]
}

Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ graph LR
7["(7) Input"]
7--"0:1<br>usize"-->8
8["(8) Output"]
9["(9) const:sum:{tag:0, vals:[]}"]
9["(9) const:seq:{}"]
9--"0:0<br>[]"-->10
10["(10) LoadConstant"]
10--"0:0<br>[]"-->8
end
6-."0:0".->1
end

150 changes: 105 additions & 45 deletions hugr-core/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,69 @@ impl AsRef<Value> for Const {
}
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
struct SerialSum {
#[serde(default)]
tag: usize,
#[serde(rename = "vs")]
values: Vec<Value>,
#[serde(default, rename = "typ")]
sum_type: Option<SumType>,
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(try_from = "SerialSum")]
#[serde(into = "SerialSum")]
/// A Sum variant, with a tag indicating the index of the variant and its
/// value.
pub struct Sum {
/// The tag index of the variant.
pub tag: usize,
/// The value of the variant.
///
/// Sum variants are always a row of values, hence the Vec.
pub values: Vec<Value>,
/// The full type of the Sum, including the other variants.
pub sum_type: SumType,
}

impl TryFrom<SerialSum> for Sum {
type Error = &'static str;

fn try_from(value: SerialSum) -> Result<Self, Self::Error> {
let SerialSum {
tag,
values,
sum_type,
} = value;

let sum_type = if let Some(sum_type) = sum_type {
sum_type
} else {
if tag != 0 {
return Err("Sum type must be provided if tag is not 0");
}
SumType::new_tuple(values.iter().map(Value::get_type).collect_vec())
};

Ok(Self {
tag,
values,
sum_type,
})
}
}

impl From<Sum> for SerialSum {
fn from(value: Sum) -> Self {
Self {
tag: value.tag,
values: value.values,
sum_type: Some(value.sum_type),
}
}
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "v")]
/// A value that can be stored as a static constant. Representing core types and
Expand All @@ -118,25 +181,10 @@ pub enum Value {
/// A Hugr defining the function.
hugr: Box<Hugr>,
},
/// A tuple
Tuple {
/// Constant values in the tuple.
vs: Vec<Value>,
},
/// A Sum variant, with a tag indicating the index of the variant and its
/// value.
Sum {
/// The tag index of the variant.
tag: usize,
/// The value of the variant.
///
/// Sum variants are always a row of values, hence the Vec.
#[serde(rename = "vs")]
values: Vec<Value>,
/// The full type of the Sum, including the other variants.
#[serde(rename = "typ")]
sum_type: SumType,
},
#[serde(alias = "Tuple")]
Sum(Sum),
}

/// An opaque newtype around a [`Box<dyn CustomConst>`](CustomConst).
Expand Down Expand Up @@ -286,8 +334,7 @@ impl Value {
pub fn get_type(&self) -> Type {
match self {
Self::Extension { e } => e.get_type(),
Self::Tuple { vs } => Type::new_tuple(vs.iter().map(Self::get_type).collect_vec()),
Self::Sum { sum_type, .. } => sum_type.clone().into(),
Self::Sum(Sum { sum_type, .. }) => sum_type.clone().into(),
Self::Function { hugr } => {
let func_type = mono_fn_type(hugr).unwrap_or_else(|e| panic!("{}", e));
Type::new_function(func_type)
Expand All @@ -305,18 +352,19 @@ impl Value {
) -> Result<Self, ConstTypeError> {
let values: Vec<Value> = items.into_iter().collect();
typ.check_type(tag, &values)?;
Ok(Self::Sum {
Ok(Self::Sum(Sum {
tag,
values,
sum_type: typ,
})
}))
}

/// Returns a tuple constant of constant values.
pub fn tuple(items: impl IntoIterator<Item = Value>) -> Self {
Self::Tuple {
vs: items.into_iter().collect(),
}
let vs = items.into_iter().collect_vec();
let tys = vs.iter().map(Self::get_type).collect_vec();

Self::sum(0, vs, SumType::new_tuple(tys)).expect("Tuple type is valid")
}

/// Returns a constant function defined by a Hugr.
Expand All @@ -334,7 +382,11 @@ impl Value {

/// Returns a constant unit type (empty Tuple).
pub const fn unit() -> Self {
Self::Tuple { vs: vec![] }
Self::Sum(Sum {
tag: 0,
values: vec![],
sum_type: SumType::Unit { size: 1 },
})
}

/// Returns a constant Sum over units. Used as branching values.
Expand Down Expand Up @@ -393,12 +445,17 @@ impl Value {
};
format!("const:function:[{}]", t)
}
Self::Tuple { vs: vals } => {
let names: Vec<_> = vals.iter().map(Value::name).collect();
format!("const:seq:{{{}}}", names.iter().join(", "))
}
Self::Sum { tag, values, .. } => {
format!("const:sum:{{tag:{tag}, vals:{values:?}}}")
Self::Sum(Sum {
tag,
values,
sum_type,
}) => {
if sum_type.as_tuple().is_some() {
let names: Vec<_> = values.iter().map(Value::name).collect();
format!("const:seq:{{{}}}", names.iter().join(", "))
} else {
format!("const:sum:{{tag:{tag}, vals:{values:?}}}")
}
}
}
.into()
Expand All @@ -409,8 +466,7 @@ impl Value {
match self {
Self::Extension { e } => e.extension_reqs().clone(),
Self::Function { .. } => ExtensionSet::new(), // no extensions required to load Hugr (only to run)
Self::Tuple { vs } => ExtensionSet::union_over(vs.iter().map(Value::extension_reqs)),
Self::Sum { values, .. } => {
Self::Sum(Sum { values, .. }) => {
ExtensionSet::union_over(values.iter().map(|x| x.extension_reqs()))
}
}
Expand All @@ -424,22 +480,26 @@ impl Value {
mono_fn_type(hugr)?;
Ok(())
}
Self::Tuple { vs } => {
for v in vs {
v.validate()?;
}
Ok(())
}
Self::Sum {
Self::Sum(Sum {
tag,
values,
sum_type,
} => {
}) => {
sum_type.check_type(*tag, values)?;
Ok(())
}
}
}

/// If value is a sum with a single row variant, return the row.
pub fn as_tuple(&self) -> Option<&[Value]> {
match self {
Self::Sum(Sum {
values, sum_type, ..
}) if sum_type.as_tuple().is_some() => Some(values),
_ => None,
}
}
}

impl<T> From<T> for Value
Expand Down Expand Up @@ -662,7 +722,7 @@ mod test {
}

mod proptest {
use super::super::OpaqueValue;
use super::super::{OpaqueValue, Sum};
use crate::{
ops::{constant::CustomSerialized, Value},
std_extensions::arithmetic::int_types::ConstInt,
Expand Down Expand Up @@ -715,19 +775,19 @@ mod test {
3, // Each collection is up to 3 elements long
|element| {
prop_oneof![
vec(element.clone(), 0..3).prop_map(|vs| Self::Tuple { vs }),
vec(element.clone(), 0..3).prop_map(Self::tuple),
(
any::<usize>(),
vec(element.clone(), 0..3),
any_with::<SumType>(1.into()) // for speed: don't generate large sum types for now
)
.prop_map(
|(tag, values, sum_type)| {
Self::Sum {
Self::Sum(Sum {
tag,
values,
sum_type,
}
})
}
),
]
Expand Down
14 changes: 14 additions & 0 deletions hugr-core/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ impl SumType {
Self::Unit { size }
}

/// New tuple (single row of variants)
pub fn new_tuple(types: impl Into<TypeRow>) -> Self {
Self::new([types.into()])
}

/// Report the tag'th variant, if it exists.
pub fn get_variant(&self, tag: usize) -> Option<&TypeRow> {
match self {
Expand All @@ -185,6 +190,15 @@ impl SumType {
SumType::General { rows } => rows.len(),
}
}

/// Returns variant row if there is only one variant
pub fn as_tuple(&self) -> Option<&TypeRow> {
match self {
SumType::Unit { .. } => Some(Type::EMPTY_TYPEROW_REF),
SumType::General { rows } if rows.len() == 1 => Some(&rows[0]),
_ => None,
}
}
}

impl From<SumType> for Type {
Expand Down
2 changes: 1 addition & 1 deletion hugr-passes/src/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR
}
OpType::UnpackTuple { .. } => {
let c = &consts.first()?.1;
let Value::Tuple { vs } = c else {
let Some(vs) = c.as_tuple() else {
panic!("This op always takes a Tuple input.");
};
out_row(vs.iter().cloned())
Expand Down

0 comments on commit df3bddd

Please sign in to comment.