Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: remove Value::Tuple #1255

Merged
merged 3 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ digraph {
7 [shape=plain label=<<table border="1"><tr><td align="text" border="0" colspan="1">(7) Input</td></tr><tr><td port="out0" align="text" colspan="1" cellpadding="1" >0: usize</td></tr></table>>]
7:out0 -> 8:in1 [style=""]
8 [shape=plain label=<<table border="1"><tr><td port="in0" align="text" colspan="1" cellpadding="1" >0: []</td><td port="in1" align="text" colspan="1" cellpadding="1" >1: usize</td></tr><tr><td align="text" border="0" colspan="2">(8) Output</td></tr></table>>]
9 [shape=plain label=<<table border="1"><tr><td align="text" border="0" colspan="1">(9) const:sum:{tag:0, vals:[]}</td></tr><tr><td port="out0" align="text" colspan="1" cellpadding="1" >0: []</td></tr></table>>]
9 [shape=plain label=<<table border="1"><tr><td align="text" border="0" colspan="1">(9) const:seq:{}</td></tr><tr><td port="out0" align="text" colspan="1" cellpadding="1" >0: []</td></tr></table>>]
9:out0 -> 10:in0 [style=""]
10 [shape=plain label=<<table border="1"><tr><td port="in0" align="text" colspan="1" cellpadding="1" >0: []</td></tr><tr><td align="text" border="0" colspan="1">(10) LoadConstant</td></tr><tr><td port="out0" align="text" colspan="1" cellpadding="1" >0: []</td></tr></table>>]
10:out0 -> 8:in0 [style=""]
Expand All @@ -44,3 +44,4 @@ hier8 [shape=plain label="8"]
hier9 [shape=plain label="9"]
hier10 [shape=plain label="10"]
}

Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ graph LR
7["(7) Input"]
7--"0:1<br>usize"-->8
8["(8) Output"]
9["(9) const:sum:{tag:0, vals:[]}"]
9["(9) const:seq:{}"]
9--"0:0<br>[]"-->10
10["(10) LoadConstant"]
10--"0:0<br>[]"-->8
end
6-."0:0".->1
end

227 changes: 182 additions & 45 deletions hugr-core/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,76 @@ impl AsRef<Value> for Const {
}
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
struct SerialSum {
#[serde(default)]
tag: usize,
#[serde(rename = "vs")]
values: Vec<Value>,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could omit this when it is empty. Value::unary_unit_sum would then omit all fields! Relevant for values of any unary sum type. Perhaps this would be better as a future improvement.

#[serde(default, rename = "typ")]
sum_type: Option<SumType>,
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(try_from = "SerialSum")]
#[serde(into = "SerialSum")]
/// A Sum variant, with a tag indicating the index of the variant and its
/// value.
pub struct Sum {
/// The tag index of the variant.
pub tag: usize,
/// The value of the variant.
///
/// Sum variants are always a row of values, hence the Vec.
pub values: Vec<Value>,
/// The full type of the Sum, including the other variants.
pub sum_type: SumType,
}

impl Sum {
/// If value is a sum with a single row variant, return the row.
pub fn as_tuple(&self) -> Option<&[Value]> {
self.sum_type.as_tuple().map(|_| self.values.as_ref())
}
}

impl TryFrom<SerialSum> for Sum {
type Error = &'static str;

fn try_from(value: SerialSum) -> Result<Self, Self::Error> {
let SerialSum {
tag,
values,
sum_type,
} = value;

let sum_type = if let Some(sum_type) = sum_type {
sum_type
} else {
if tag != 0 {
return Err("Sum type must be provided if tag is not 0");
}
SumType::new_tuple(values.iter().map(Value::get_type).collect_vec())
};

Ok(Self {
tag,
values,
sum_type,
})
}
}

impl From<Sum> for SerialSum {
fn from(value: Sum) -> Self {
Self {
tag: value.tag,
values: value.values,
sum_type: Some(value.sum_type),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
sum_type: Some(value.sum_type),
(value.sum_type.num_variants() > 1).then_some(value.sum_type),

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is schema breaking - Sums are not allowed to have null sum_type

(we won't be generating serialised "Tuple" from the rust any longer, I think to do that we would need to implement Serialize for Value manually, which is annoying).

I'm happy with this compromise, but to really keep the "serialised tuple" optimisation I think we would want to rethink the serialised schema

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, that is annoying.

}
}
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "v")]
/// A value that can be stored as a static constant. Representing core types and
Expand All @@ -118,25 +188,10 @@ pub enum Value {
/// A Hugr defining the function.
hugr: Box<Hugr>,
},
/// A tuple
Tuple {
/// Constant values in the tuple.
vs: Vec<Value>,
},
/// A Sum variant, with a tag indicating the index of the variant and its
/// value.
Sum {
/// The tag index of the variant.
tag: usize,
/// The value of the variant.
///
/// Sum variants are always a row of values, hence the Vec.
#[serde(rename = "vs")]
values: Vec<Value>,
/// The full type of the Sum, including the other variants.
#[serde(rename = "typ")]
sum_type: SumType,
},
#[serde(alias = "Tuple")]
Sum(Sum),
}

/// An opaque newtype around a [`Box<dyn CustomConst>`](CustomConst).
Expand Down Expand Up @@ -286,8 +341,7 @@ impl Value {
pub fn get_type(&self) -> Type {
match self {
Self::Extension { e } => e.get_type(),
Self::Tuple { vs } => Type::new_tuple(vs.iter().map(Self::get_type).collect_vec()),
Self::Sum { sum_type, .. } => sum_type.clone().into(),
Self::Sum(Sum { sum_type, .. }) => sum_type.clone().into(),
Self::Function { hugr } => {
let func_type = mono_fn_type(hugr).unwrap_or_else(|e| panic!("{}", e));
Type::new_function(func_type)
Expand All @@ -305,18 +359,19 @@ impl Value {
) -> Result<Self, ConstTypeError> {
let values: Vec<Value> = items.into_iter().collect();
typ.check_type(tag, &values)?;
Ok(Self::Sum {
Ok(Self::Sum(Sum {
tag,
values,
sum_type: typ,
})
}))
}

/// Returns a tuple constant of constant values.
pub fn tuple(items: impl IntoIterator<Item = Value>) -> Self {
Self::Tuple {
vs: items.into_iter().collect(),
}
let vs = items.into_iter().collect_vec();
let tys = vs.iter().map(Self::get_type).collect_vec();

Self::sum(0, vs, SumType::new_tuple(tys)).expect("Tuple type is valid")
}

/// Returns a constant function defined by a Hugr.
Expand All @@ -334,7 +389,11 @@ impl Value {

/// Returns a constant unit type (empty Tuple).
pub const fn unit() -> Self {
Self::Tuple { vs: vec![] }
Self::Sum(Sum {
tag: 0,
values: vec![],
sum_type: SumType::Unit { size: 1 },
})
}

/// Returns a constant Sum over units. Used as branching values.
Expand Down Expand Up @@ -393,12 +452,17 @@ impl Value {
};
format!("const:function:[{}]", t)
}
Self::Tuple { vs: vals } => {
let names: Vec<_> = vals.iter().map(Value::name).collect();
format!("const:seq:{{{}}}", names.iter().join(", "))
}
Self::Sum { tag, values, .. } => {
format!("const:sum:{{tag:{tag}, vals:{values:?}}}")
Self::Sum(Sum {
tag,
values,
sum_type,
}) => {
if sum_type.as_tuple().is_some() {
let names: Vec<_> = values.iter().map(Value::name).collect();
format!("const:seq:{{{}}}", names.iter().join(", "))
} else {
format!("const:sum:{{tag:{tag}, vals:{values:?}}}")
}
}
}
.into()
Expand All @@ -409,8 +473,7 @@ impl Value {
match self {
Self::Extension { e } => e.extension_reqs().clone(),
Self::Function { .. } => ExtensionSet::new(), // no extensions required to load Hugr (only to run)
Self::Tuple { vs } => ExtensionSet::union_over(vs.iter().map(Value::extension_reqs)),
Self::Sum { values, .. } => {
Self::Sum(Sum { values, .. }) => {
ExtensionSet::union_over(values.iter().map(|x| x.extension_reqs()))
}
}
Expand All @@ -424,22 +487,25 @@ impl Value {
mono_fn_type(hugr)?;
Ok(())
}
Self::Tuple { vs } => {
for v in vs {
v.validate()?;
}
Ok(())
}
Self::Sum {
Self::Sum(Sum {
tag,
values,
sum_type,
} => {
}) => {
sum_type.check_type(*tag, values)?;
Ok(())
}
}
}

/// If value is a sum with a single row variant, return the row.
pub fn as_tuple(&self) -> Option<&[Value]> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should add this method to Sum as well, and delegate to that here.

if let Self::Sum(sum) = self {
sum.as_tuple()
} else {
None
}
}
}

impl<T> From<T> for Value
Expand Down Expand Up @@ -662,7 +728,7 @@ mod test {
}

mod proptest {
use super::super::OpaqueValue;
use super::super::{OpaqueValue, Sum};
use crate::{
ops::{constant::CustomSerialized, Value},
std_extensions::arithmetic::int_types::ConstInt,
Expand Down Expand Up @@ -715,19 +781,19 @@ mod test {
3, // Each collection is up to 3 elements long
|element| {
prop_oneof![
vec(element.clone(), 0..3).prop_map(|vs| Self::Tuple { vs }),
vec(element.clone(), 0..3).prop_map(Self::tuple),
(
any::<usize>(),
vec(element.clone(), 0..3),
any_with::<SumType>(1.into()) // for speed: don't generate large sum types for now
)
.prop_map(
|(tag, values, sum_type)| {
Self::Sum {
Self::Sum(Sum {
tag,
values,
sum_type,
}
})
}
),
]
Expand All @@ -737,4 +803,75 @@ mod test {
}
}
}

#[test]
fn test_tuple_deserialize() {
let json = r#"
{
"v": "Tuple",
"vs": [
{
"v": "Sum",
"tag": 0,
"typ": {
"t": "Sum",
"s": "Unit",
"size": 1
},
"vs": []
},
{
"v": "Sum",
"tag": 1,
"typ": {
"t": "Sum",
"s": "General",
"rows": [
[
{
"t": "Sum",
"s": "Unit",
"size": 1
}
],
[
{
"t": "Sum",
"s": "Unit",
"size": 2
}
]
]
},
"vs": [
{
"v": "Sum",
"tag": 1,
"typ": {
"t": "Sum",
"s": "Unit",
"size": 2
},
"vs": []
}
]
}
]
}
"#;

let v: Value = serde_json::from_str(json).unwrap();
assert_eq!(
v,
Value::tuple([
Value::unit(),
Value::sum(
1,
[Value::true_val()],
SumType::new([vec![Type::UNIT], vec![Value::true_val().get_type()]]),
)
.unwrap()
])
);
}
}
14 changes: 14 additions & 0 deletions hugr-core/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ impl SumType {
Self::Unit { size }
}

/// New tuple (single row of variants)
pub fn new_tuple(types: impl Into<TypeRow>) -> Self {
Self::new([types.into()])
}

/// Report the tag'th variant, if it exists.
pub fn get_variant(&self, tag: usize) -> Option<&TypeRow> {
match self {
Expand All @@ -185,6 +190,15 @@ impl SumType {
SumType::General { rows } => rows.len(),
}
}

/// Returns variant row if there is only one variant
pub fn as_tuple(&self) -> Option<&TypeRow> {
match self {
SumType::Unit { size } if *size == 1 => Some(Type::EMPTY_TYPEROW_REF),
SumType::General { rows } if rows.len() == 1 => Some(&rows[0]),
_ => None,
}
}
}

impl From<SumType> for Type {
Expand Down
2 changes: 1 addition & 1 deletion hugr-passes/src/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR
}
OpType::UnpackTuple { .. } => {
let c = &consts.first()?.1;
let Value::Tuple { vs } = c else {
let Some(vs) = c.as_tuple() else {
panic!("This op always takes a Tuple input.");
};
out_row(vs.iter().cloned())
Expand Down
Loading