Skip to content

Commit

Permalink
refactor!: Make Either::Right the "success" case (#1489)
Browse files Browse the repository at this point in the history
Similarly for options, which now have type `[] + [elems...]`.
Adds `const_ok`/`_fail` aliases to `const_right`/`_left`.

Closes #1487.

Modifies the old folding tests that hard-coded the sum tags.

BREAKING CHANGE: Binary sums representing fallible values now use tag
`1` for the successful variant
  • Loading branch information
aborgna-q authored Aug 30, 2024
1 parent aca403a commit 8caa572
Show file tree
Hide file tree
Showing 16 changed files with 344 additions and 341 deletions.
70 changes: 56 additions & 14 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,24 +246,24 @@ pub const ERROR_TYPE: Type = Type::new_extension(ERROR_CUSTOM_TYPE);
/// The string name of the error type.
pub const ERROR_TYPE_NAME: TypeName = TypeName::new_inline("error");

/// Return a Sum type with the first variant as the given type and the second an Error.
/// Return a Sum type with the second variant as the given type and the first an Error.
pub fn sum_with_error(ty: impl Into<TypeRowRV>) -> SumType {
either_type(ty, ERROR_TYPE)
either_type(ERROR_TYPE, ty)
}

/// An optional type, i.e. a Sum type with the first variant as the given type and the second as an empty tuple.
/// An optional type, i.e. a Sum type with the second variant as the given type and the first as an empty tuple.
#[inline]
pub fn option_type(ty: impl Into<TypeRowRV>) -> SumType {
either_type(ty, TypeRow::new())
either_type(TypeRow::new(), ty)
}

/// An "either" type, i.e. a Sum type with a "left" and a "right" variant.
///
/// When used as a fallible value, the "left" variant represents a successful computation,
/// and the "right" variant represents a failure.
/// When used as a fallible value, the "right" variant represents a successful computation,
/// and the "left" variant represents a failure.
#[inline]
pub fn either_type(ty_ok: impl Into<TypeRowRV>, ty_err: impl Into<TypeRowRV>) -> SumType {
SumType::new([ty_ok.into(), ty_err.into()])
pub fn either_type(ty_left: impl Into<TypeRowRV>, ty_right: impl Into<TypeRowRV>) -> SumType {
SumType::new([ty_left.into(), ty_right.into()])
}

/// A constant optional value with a given value.
Expand All @@ -279,19 +279,19 @@ pub fn const_some(value: Value) -> Value {
///
/// See [option_type].
pub fn const_some_tuple(values: impl IntoIterator<Item = Value>) -> Value {
const_left_tuple(values, TypeRow::new())
const_right_tuple(TypeRow::new(), values)
}

/// A constant optional value with no value.
///
/// See [option_type].
pub fn const_none(ty: impl Into<TypeRowRV>) -> Value {
const_right_tuple(ty, [])
const_left_tuple([], ty)
}

/// A constant Either value with a left variant.
///
/// In fallible computations, this represents a successful result.
/// In fallible computations, this represents a failure.
///
/// See [either_type].
pub fn const_left(value: Value, ty_right: impl Into<TypeRowRV>) -> Value {
Expand All @@ -300,7 +300,7 @@ pub fn const_left(value: Value, ty_right: impl Into<TypeRowRV>) -> Value {

/// A constant Either value with a row of left values.
///
/// In fallible computations, this represents a successful result.
/// In fallible computations, this represents a failure.
///
/// See [either_type].
pub fn const_left_tuple(
Expand All @@ -319,7 +319,7 @@ pub fn const_left_tuple(

/// A constant Either value with a right variant.
///
/// In fallible computations, this represents a failure.
/// In fallible computations, this represents a successful result.
///
/// See [either_type].
pub fn const_right(ty_left: impl Into<TypeRowRV>, value: Value) -> Value {
Expand All @@ -328,7 +328,7 @@ pub fn const_right(ty_left: impl Into<TypeRowRV>, value: Value) -> Value {

/// A constant Either value with a row of right values.
///
/// In fallible computations, this represents a failure.
/// In fallible computations, this represents a successful result.
///
/// See [either_type].
pub fn const_right_tuple(
Expand All @@ -345,6 +345,40 @@ pub fn const_right_tuple(
Value::sum(1, values, typ).unwrap()
}

/// A constant Either value with a success variant.
///
/// Alias for [const_right].
pub fn const_ok(value: Value, ty_fail: impl Into<TypeRowRV>) -> Value {
const_right(ty_fail, value)
}

/// A constant Either with a row of success values.
///
/// Alias for [const_right_tuple].
pub fn const_ok_tuple(
values: impl IntoIterator<Item = Value>,
ty_fail: impl Into<TypeRowRV>,
) -> Value {
const_right_tuple(ty_fail, values)
}

/// A constant Either value with a failure variant.
///
/// Alias for [const_left].
pub fn const_fail(value: Value, ty_ok: impl Into<TypeRowRV>) -> Value {
const_left(value, ty_ok)
}

/// A constant Either with a row of failure values.
///
/// Alias for [const_left_tuple].
pub fn const_fail_tuple(
values: impl IntoIterator<Item = Value>,
ty_ok: impl Into<TypeRowRV>,
) -> Value {
const_left_tuple(values, ty_ok)
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
/// Structure for holding constant usize values.
pub struct ConstUsize(u64);
Expand Down Expand Up @@ -397,6 +431,14 @@ impl ConstError {
message: message.to_string(),
}
}

/// Returns an "either" value with a failure variant.
///
/// args:
/// ty_ok: The type of the success variant.
pub fn as_either(self, ty_ok: impl Into<TypeRowRV>) -> Value {
const_fail(self.into(), ty_ok)
}
}

#[typetag::serde]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::ops::Value;
use crate::std_extensions::arithmetic::int_types::INT_TYPES;
use crate::{
extension::{
prelude::{sum_with_error, ConstError},
prelude::{const_ok, ConstError, ERROR_TYPE},
ConstFold, ConstFoldResult, OpDef,
},
ops,
Expand Down Expand Up @@ -40,21 +40,19 @@ fn fold_trunc(
};
let log_width = get_log_width(arg).ok()?;
let int_type = INT_TYPES[log_width as usize].to_owned();
let sum_type = sum_with_error(int_type.clone());
let err_value = || {
let err_val = ConstError {
ConstError {
signal: 0,
message: "Can't truncate non-finite float".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
}
.as_either(int_type.clone())
};
let out_const: ops::Value = if !f.is_finite() {
err_value()
} else {
let cv = convert(f, log_width);
if let Ok(cv) = cv {
Value::sum(0, [cv], sum_type).unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
const_ok(cv, ERROR_TYPE)
} else {
err_value()
}
Expand Down
58 changes: 23 additions & 35 deletions hugr-core/src/std_extensions/arithmetic/int_ops/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
Value,
},
std_extensions::arithmetic::int_types::{get_log_width, ConstInt, INT_TYPES},
types::{SumType, Type, TypeArg},
types::{Type, TypeArg},
IncomingPort,
};

Expand Down Expand Up @@ -132,9 +132,9 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
};
let n0val: u64 = n0.value_u();
let out_const: Value = if n0val >> (1 << logwidth1) != 0 {
mk_out_const(1, Ok(INARROW_ERROR_VALUE.clone()))
mk_out_const(0, Ok(INARROW_ERROR_VALUE.clone()))
} else {
mk_out_const(0, ConstInt::new_u(logwidth1, n0val).map(Into::into))
mk_out_const(1, ConstInt::new_u(logwidth1, n0val).map(Into::into))
};
Some(vec![(0.into(), out_const)])
},
Expand All @@ -160,9 +160,9 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
let n0val: i64 = n0.value_s();
let ub = 1i64 << ((1 << logwidth1) - 1);
let out_const: Value = if n0val >= ub || n0val < -ub {
mk_out_const(1, Ok(INARROW_ERROR_VALUE.clone()))
mk_out_const(0, Ok(INARROW_ERROR_VALUE.clone()))
} else {
mk_out_const(0, ConstInt::new_s(logwidth1, n0val).map(Into::into))
mk_out_const(1, ConstInt::new_s(logwidth1, n0val).map(Into::into))
};
Some(vec![(0.into(), out_const)])
},
Expand Down Expand Up @@ -631,14 +631,12 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
let q_type = INT_TYPES[logwidth0 as usize].to_owned();
let r_type = q_type.clone();
let qr_type: Type = Type::new_tuple(vec![q_type, r_type]);
let sum_type: SumType = sum_with_error(qr_type);
let err_value = || {
let err_val = ConstError {
ConstError {
signal: 0,
message: "Division by zero".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
}
.as_either(qr_type)
};
let nval = n.value_u();
let mval = m.value_u();
Expand Down Expand Up @@ -694,14 +692,12 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
let q_type = INT_TYPES[logwidth0 as usize].to_owned();
let r_type = INT_TYPES[logwidth0 as usize].to_owned();
let qr_type: Type = Type::new_tuple(vec![q_type, r_type]);
let sum_type: SumType = sum_with_error(qr_type);
let err_value = || {
let err_val = ConstError {
ConstError {
signal: 0,
message: "Division by zero".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
}
.as_either(qr_type)
};
let nval = n.value_s();
let mval = m.value_u();
Expand Down Expand Up @@ -754,14 +750,12 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
None
} else {
let int_out_type = INT_TYPES[logwidth0 as usize].to_owned();
let sum_type = sum_with_error(int_out_type.clone());
let err_value = || {
let err_val = ConstError {
ConstError {
signal: 0,
message: "Division by zero".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
}
.as_either(int_out_type.clone())
};
let nval = n.value_u();
let mval = m.value_u();
Expand Down Expand Up @@ -808,14 +802,12 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
None
} else {
let int_out_type = INT_TYPES[logwidth0 as usize].to_owned();
let sum_type = sum_with_error(int_out_type.clone());
let err_value = || {
let err_val = ConstError {
ConstError {
signal: 0,
message: "Division by zero".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
}
.as_either(int_out_type.clone())
};
let nval = n.value_u();
let mval = m.value_u();
Expand Down Expand Up @@ -862,14 +854,12 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
None
} else {
let int_out_type = INT_TYPES[logwidth0 as usize].to_owned();
let sum_type = sum_with_error(int_out_type.clone());
let err_value = || {
let err_val = ConstError {
ConstError {
signal: 0,
message: "Division by zero".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
}
.as_either(int_out_type.clone())
};
let nval = n.value_s();
let mval = m.value_u();
Expand Down Expand Up @@ -918,14 +908,12 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
None
} else {
let int_out_type = INT_TYPES[logwidth0 as usize].to_owned();
let sum_type = sum_with_error(int_out_type.clone());
let err_value = || {
let err_val = ConstError {
ConstError {
signal: 0,
message: "Division by zero".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
}
.as_either(int_out_type.clone())
};
let nval = n.value_s();
let mval = m.value_u();
Expand Down
6 changes: 3 additions & 3 deletions hugr-core/src/std_extensions/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ mod test {
use rstest::rstest;

use crate::extension::prelude::{
const_left_tuple, const_none, const_right_tuple, const_some_tuple,
const_fail_tuple, const_none, const_ok_tuple, const_some_tuple,
};
use crate::ops::OpTrait;
use crate::PortIndex;
Expand Down Expand Up @@ -467,11 +467,11 @@ mod test {
TestVal::None(tr) => const_none(tr.clone()),
TestVal::Ok(l, tr) => {
let elems = l.iter().map(TestVal::to_value);
const_left_tuple(elems, tr.clone())
const_ok_tuple(elems, tr.clone())
}
TestVal::Err(tr, l) => {
let elems = l.iter().map(TestVal::to_value);
const_right_tuple(tr.clone(), elems)
const_fail_tuple(elems, tr.clone())
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions hugr-core/src/std_extensions/collections/list_fold.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Folding definitions for list operations.
use crate::extension::prelude::{
const_left, const_left_tuple, const_none, const_right, const_some, ConstUsize,
const_fail, const_none, const_ok, const_ok_tuple, const_some, ConstUsize,
};
use crate::extension::{ConstFold, ConstFoldResult, OpDef};
use crate::ops::Value;
Expand Down Expand Up @@ -96,9 +96,9 @@ impl ConstFold for SetFold {
let res_elem: Value = match list.0.get_mut(idx) {
Some(old_elem) => {
std::mem::swap(old_elem, &mut elem);
const_left(elem, list.1.clone())
const_ok(elem, list.1.clone())
}
None => const_right(list.1.clone(), elem),
None => const_fail(elem, list.1.clone()),
};
Some(vec![(0.into(), list.into()), (1.into(), res_elem)])
}
Expand All @@ -118,9 +118,9 @@ impl ConstFold for InsertFold {
let elem = elem.clone();
let res_elem: Value = if list.0.len() > idx {
list.0.insert(idx, elem);
const_left_tuple([], list.1.clone())
const_ok_tuple([], list.1.clone())
} else {
const_right(Type::UNIT, elem)
const_fail(elem, Type::UNIT)
};
Some(vec![(0.into(), list.into()), (1.into(), res_elem)])
}
Expand Down
Loading

0 comments on commit 8caa572

Please sign in to comment.