Skip to content

Commit

Permalink
use named fields
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Dec 17, 2024
1 parent e2d454e commit b4ee9a9
Showing 1 changed file with 34 additions and 15 deletions.
49 changes: 34 additions & 15 deletions hugr-core/src/std_extensions/collections/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,40 +36,49 @@ pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
/// Statically sized array of values, all of the same type.
pub struct ArrayValue(Vec<Value>, Type);
pub struct ArrayValue {
values: Vec<Value>,
typ: Type,
}

impl ArrayValue {
/// Create a new [CustomConst] for an array of values of type `typ`.
/// That all values are of type `typ` is not checked here.
pub fn new(typ: Type, contents: impl IntoIterator<Item = Value>) -> Self {
Self(contents.into_iter().collect_vec(), typ)
Self {
values: contents.into_iter().collect_vec(),
typ,
}
}

/// Create a new [CustomConst] for an empty array of values of type `typ`.
pub fn new_empty(typ: Type) -> Self {
Self(vec![], typ)
Self {
values: vec![],
typ,
}
}

/// Returns the type of the `[ArrayValue]` as a `[CustomType]`.`
pub fn custom_type(&self) -> CustomType {
array_custom_type(self.0.len() as u64, self.1.clone())
array_custom_type(self.values.len() as u64, self.typ.clone())
}

/// Returns the type of values inside the `[ArrayValue]`.
pub fn get_element_type(&self) -> &Type {
&self.1
&self.typ
}

/// Returns the values contained inside the `[ArrayValue]`.
pub fn get_contents(&self) -> &[Value] {
&self.0
&self.values
}
}

impl TryHash for ArrayValue {
fn try_hash(&self, mut st: &mut dyn Hasher) -> bool {
maybe_hash_values(&self.0, &mut st) && {
self.1.hash(&mut st);
maybe_hash_values(&self.values, &mut st) && {
self.typ.hash(&mut st);
true
}
}
Expand Down Expand Up @@ -100,7 +109,11 @@ impl CustomConst for ArrayValue {

// constant can only hold classic type.
let ty = match typ.args() {
[TypeArg::BoundedNat { n }, TypeArg::Type { ty }] if *n as usize == self.0.len() => ty,
[TypeArg::BoundedNat { n }, TypeArg::Type { ty }]
if *n as usize == self.values.len() =>
{
ty
}
_ => {
return Err(CustomCheckFailure::Message(format!(
"Invalid array type arguments: {:?}",
Expand All @@ -110,7 +123,7 @@ impl CustomConst for ArrayValue {
};

// check all values are instances of the element type
for v in &self.0 {
for v in &self.values {
if v.get_type() != *ty {
return Err(CustomCheckFailure::Message(format!(
"Array element {v:?} is not of expected type {ty}"
Expand All @@ -126,18 +139,18 @@ impl CustomConst for ArrayValue {
}

fn extension_reqs(&self) -> ExtensionSet {
ExtensionSet::union_over(self.0.iter().map(Value::extension_reqs))
ExtensionSet::union_over(self.values.iter().map(Value::extension_reqs))
.union(EXTENSION_ID.into())
}

fn update_extensions(
&mut self,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
for val in &mut self.0 {
for val in &mut self.values {
resolve_value_extensions(val, extensions)?;
}
resolve_type_extensions(&mut self.1, extensions)
resolve_type_extensions(&mut self.typ, extensions)
}
}

Expand Down Expand Up @@ -239,11 +252,17 @@ mod test {

#[test]
fn test_array_value() {
let array_value = ArrayValue(vec![ConstUsize::new(3).into()], usize_t());
let array_value = ArrayValue {
values: vec![ConstUsize::new(3).into()],
typ: usize_t(),
};

array_value.validate().unwrap();

let wrong_array_value = ArrayValue(vec![ConstF64::new(1.2).into()], usize_t());
let wrong_array_value = ArrayValue {
values: vec![ConstF64::new(1.2).into()],
typ: usize_t(),
};
assert!(wrong_array_value.validate().is_err());
}
}

0 comments on commit b4ee9a9

Please sign in to comment.