Skip to content

Commit

Permalink
Rename {Const,Value}::const_type -> get_type. Add some useful functio…
Browse files Browse the repository at this point in the history
…ns for `ops::constant::ExtensionValue`
  • Loading branch information
doug-q committed May 8, 2024
1 parent e5fd315 commit f77f5ed
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 65 deletions.
4 changes: 2 additions & 2 deletions hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR
/// Generate a graph that loads and outputs `consts` in order, validating
/// against `reg`.
fn const_graph(consts: Vec<Value>, reg: &ExtensionRegistry) -> Hugr {
let const_types = consts.iter().map(Value::const_type).collect_vec();
let const_types = consts.iter().map(Value::get_type).collect_vec();
let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap();

let outputs = consts
Expand Down Expand Up @@ -337,7 +337,7 @@ mod test {
let list: Value = ListValue::new(BOOL_T, [Value::unit_sum(0, 1).unwrap()]).into();
let mut build = DFGBuilder::new(FunctionType::new(
type_row![],
vec![list.const_type().clone()],
vec![list.get_type().clone()],
))
.unwrap();

Expand Down
2 changes: 1 addition & 1 deletion hugr/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ pub trait Dataflow: Container {
let load_n = self
.add_dataflow_op(
ops::LoadConstant {
datatype: op.const_type().clone(),
datatype: op.get_type().clone(),
},
// Constant wire from the constant value node
vec![Wire::new(const_node, OutgoingPort::from(0))],
Expand Down
2 changes: 1 addition & 1 deletion hugr/src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ mod test {
let const_wire = loop_b.add_load_const(Value::true_val());
let lift_node = loop_b.add_dataflow_op(
ops::Lift {
type_row: vec![const_val.const_type().clone()].into(),
type_row: vec![const_val.get_type().clone()].into(),
new_extension: PRELUDE_ID,
},
[const_wire],
Expand Down
120 changes: 71 additions & 49 deletions hugr/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::extension::ExtensionSet;
use crate::types::{CustomType, EdgeKind, FunctionType, SumType, SumTypeError, Type};
use crate::{Hugr, HugrView};

use delegate::delegate;
use itertools::Itertools;
use smol_str::SmolStr;
use thiserror::Error;
Expand Down Expand Up @@ -35,9 +36,11 @@ impl Const {
&self.value
}

/// Returns a reference to the type of this constant.
pub fn const_type(&self) -> Type {
self.value.const_type()
delegate! {
to self.value {
/// Returns the type of this constant.
pub fn get_type(&self) -> Type;
}
}
}

Expand All @@ -47,6 +50,34 @@ impl From<Value> for Const {
}
}

impl NamedOp for Const {
fn name(&self) -> OpName {
self.value().name()
}
}

impl StaticTag for Const {
const TAG: OpTag = OpTag::Const;
}

impl OpTrait for Const {
fn description(&self) -> &str {
"Constant value"
}

fn extension_delta(&self) -> ExtensionSet {
self.value().extension_reqs()
}

fn tag(&self) -> OpTag {
<Self as StaticTag>::TAG
}

fn static_output(&self) -> Option<EdgeKind> {
Some(EdgeKind::Const(self.get_type()))
}
}

impl From<Const> for Value {
fn from(konst: Const) -> Self {
konst.value
Expand Down Expand Up @@ -96,16 +127,36 @@ pub enum Value {
},
}

/// Boxed [`CustomConst`] trait object.
/// An opaque newtype awround a `Box<dyn CustomConst>`.
///
/// Use [`Value::extension`] to create a new variant of this type.
///
/// This is required to avoid <https://github.com/rust-lang/rust/issues/78808> in
/// [`Value::Extension`], while implementing a transparent encoding into a
/// `CustomConst`.
/// This will be the serialisation barrier that ensures all implementors of
/// [`CustomConst`] are serialised through [`CustomSerialized`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(transparent)]
pub struct ExtensionValue(pub(super) Box<dyn CustomConst>);
pub struct ExtensionValue(Box<dyn CustomConst>);

impl ExtensionValue {
/// Create a new [`ExtensionValue`] from any [`CustomConst`].
pub fn new(cc: impl CustomConst) -> Self {
Self(Box::new(cc))
}

/// Returns a reference to the internal [`CustomConst`].
pub fn value(&self) -> &dyn CustomConst {
self.0.as_ref()
}

delegate! {
to self.0 {
/// Returns the type of the internal [`CustomConst`].
pub fn get_type(&self) -> Type;
/// An identifier of the internal [`CustomConst`].
pub fn name(&self) -> ValueName;
/// The extension(s) defining the internal [`CustomConst`].
pub fn extension_reqs(&self) -> ExtensionSet;
}
}
}

impl PartialEq for ExtensionValue {
fn eq(&self, other: &Self) -> bool {
Expand Down Expand Up @@ -166,11 +217,11 @@ fn mono_fn_type(h: &Hugr) -> Result<FunctionType, ConstTypeError> {
}

impl Value {
/// Returns a reference to the type of this [`Value`].
pub fn const_type(&self) -> Type {
/// Returns the type of this [`Value`].
pub fn get_type(&self) -> Type {
match self {
Self::Extension { e } => e.0.get_type(),
Self::Tuple { vs } => Type::new_tuple(vs.iter().map(Self::const_type).collect_vec()),
Self::Extension { e } => e.get_type(),
Self::Tuple { vs } => Type::new_tuple(vs.iter().map(Self::get_type).collect_vec()),
Self::Sum { sum_type, .. } => sum_type.clone().into(),
Self::Function { hugr } => {
let func_type = mono_fn_type(hugr).unwrap_or_else(|e| panic!("{}", e));
Expand Down Expand Up @@ -268,7 +319,7 @@ impl Value {

fn name(&self) -> OpName {
match self {
Self::Extension { e } => format!("const:custom:{}", e.0.name()),
Self::Extension { e } => format!("const:custom:{}", e.name()),
Self::Function { hugr: h } => {
let Some(t) = h.get_function_type() else {
panic!("HUGR root node isn't a valid function parent.");
Expand All @@ -289,7 +340,7 @@ impl Value {
/// The extensions required by a [`Value`]
pub fn extension_reqs(&self) -> ExtensionSet {
match self {
Self::Extension { e } => e.0.extension_reqs().clone(),
Self::Extension { e } => e.extension_reqs().clone(),
Self::Function { .. } => ExtensionSet::new(), // no extensions required to load Hugr (only to run)
Self::Tuple { vs } => ExtensionSet::union_over(vs.iter().map(Value::extension_reqs)),
Self::Sum { values, .. } => {
Expand All @@ -299,35 +350,6 @@ impl Value {
}
}

impl NamedOp for Const {
fn name(&self) -> OpName {
self.value().name()
}
}

impl StaticTag for Const {
const TAG: OpTag = OpTag::Const;
}
impl OpTrait for Const {
fn description(&self) -> &str {
"Constant value"
}

fn extension_delta(&self) -> ExtensionSet {
self.value().extension_reqs()
}

fn tag(&self) -> OpTag {
<Self as StaticTag>::TAG
}

fn static_output(&self) -> Option<EdgeKind> {
Some(EdgeKind::Const(self.const_type()))
}
}

// [KnownTypeConst] is guaranteed to be the right type, so can be constructed
// without initial type check.
impl<T> From<T> for Value
where
T: CustomConst,
Expand Down Expand Up @@ -484,7 +506,7 @@ mod test {
crate::extension::prelude::BOOL_T
]));

assert_eq!(v.const_type(), correct_type);
assert_eq!(v.get_type(), correct_type);
assert!(v.name().starts_with("const:function:"))
}

Expand All @@ -508,7 +530,7 @@ mod test {
#[case] expected_type: Type,
#[case] name_prefix: &str,
) {
assert_eq!(const_value.const_type(), expected_type);
assert_eq!(const_value.get_type(), expected_type);
let name = const_value.name();
assert!(
name.starts_with(name_prefix),
Expand Down Expand Up @@ -541,10 +563,10 @@ mod test {
.into();
let classic_t = Type::new_extension(typ_int.clone());
assert_matches!(classic_t.least_upper_bound(), TypeBound::Eq);
assert_eq!(yaml_const.const_type(), classic_t);
assert_eq!(yaml_const.get_type(), classic_t);

let typ_qb = CustomType::new("my_type", vec![], ex_id, TypeBound::Eq);
let t = Type::new_extension(typ_qb.clone());
assert_ne!(yaml_const.const_type(), t);
assert_ne!(yaml_const.get_type(), t);
}
}
16 changes: 8 additions & 8 deletions hugr/src/ops/constant/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub trait CustomConst:
/// [USize]: crate::extension::prelude::USIZE_T
fn extension_reqs(&self) -> ExtensionSet;

/// Check the value is a valid instance of the provided type.
/// Check the value.
fn validate(&self) -> Result<(), CustomCheckFailure> {
Ok(())
}
Expand All @@ -48,10 +48,16 @@ pub trait CustomConst:
false
}

/// report the type
/// Report the type.
fn get_type(&self) -> Type;
}

impl PartialEq for dyn CustomConst {
fn eq(&self, other: &Self) -> bool {
(*self).equal_consts(other)
}
}

/// Const equality for types that have PartialEq
pub fn downcast_equal_consts<T: CustomConst + PartialEq>(
constant: &T,
Expand Down Expand Up @@ -112,9 +118,3 @@ impl CustomConst for CustomSerialized {
self.typ.clone()
}
}

impl PartialEq for dyn CustomConst {
fn eq(&self, other: &Self) -> bool {
(*self).equal_consts(other)
}
}
2 changes: 1 addition & 1 deletion hugr/src/std_extensions/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl CustomConst for ListValue {

// check all values are instances of the element type
for v in &self.0 {
if v.const_type() != *ty {
if v.get_type() != *ty {
return Err(error());
}
}
Expand Down
2 changes: 1 addition & 1 deletion hugr/src/std_extensions/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ pub(crate) mod test {
let true_val = r.get_value(&TRUE_NAME).unwrap();

for v in [false_val, true_val] {
let simpl = v.typed_value().const_type();
let simpl = v.typed_value().get_type();
assert_eq!(simpl, BOOL_T);
}
}
Expand Down
4 changes: 2 additions & 2 deletions hugr/src/types/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::ops::Value;
#[non_exhaustive]
pub enum SumTypeError {
/// The type of the variant doesn't match the type of the value.
#[error("Expected type {expected} for element {index} of variant #{tag}, but found {}", .found.const_type())]
#[error("Expected type {expected} for element {index} of variant #{tag}, but found {}", .found.get_type())]
InvalidValueType {
/// Tag of the variant.
tag: usize,
Expand Down Expand Up @@ -70,7 +70,7 @@ impl super::SumType {
}

for (index, (t, v)) in itertools::zip_eq(variant.iter(), val.iter()).enumerate() {
if v.const_type() != *t {
if v.get_type() != *t {
Err(SumTypeError::InvalidValueType {
tag,
index,
Expand Down

0 comments on commit f77f5ed

Please sign in to comment.