diff --git a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index 8fc1fe6fe..c1ccfe57e 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -36,40 +36,49 @@ pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] /// Statically sized array of values, all of the same type. -pub struct ArrayValue(Vec, Type); +pub struct ArrayValue { + values: Vec, + typ: Type, +} impl ArrayValue { /// Create a new [CustomConst] for an array of values of type `typ`. /// That all values are of type `typ` is not checked here. pub fn new(typ: Type, contents: impl IntoIterator) -> Self { - Self(contents.into_iter().collect_vec(), typ) + Self { + values: contents.into_iter().collect_vec(), + typ, + } } /// Create a new [CustomConst] for an empty array of values of type `typ`. pub fn new_empty(typ: Type) -> Self { - Self(vec![], typ) + Self { + values: vec![], + typ, + } } /// Returns the type of the `[ArrayValue]` as a `[CustomType]`.` pub fn custom_type(&self) -> CustomType { - array_custom_type(self.0.len() as u64, self.1.clone()) + array_custom_type(self.values.len() as u64, self.typ.clone()) } /// Returns the type of values inside the `[ArrayValue]`. pub fn get_element_type(&self) -> &Type { - &self.1 + &self.typ } /// Returns the values contained inside the `[ArrayValue]`. pub fn get_contents(&self) -> &[Value] { - &self.0 + &self.values } } impl TryHash for ArrayValue { fn try_hash(&self, mut st: &mut dyn Hasher) -> bool { - maybe_hash_values(&self.0, &mut st) && { - self.1.hash(&mut st); + maybe_hash_values(&self.values, &mut st) && { + self.typ.hash(&mut st); true } } @@ -100,7 +109,11 @@ impl CustomConst for ArrayValue { // constant can only hold classic type. let ty = match typ.args() { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] if *n as usize == self.0.len() => ty, + [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] + if *n as usize == self.values.len() => + { + ty + } _ => { return Err(CustomCheckFailure::Message(format!( "Invalid array type arguments: {:?}", @@ -110,7 +123,7 @@ impl CustomConst for ArrayValue { }; // check all values are instances of the element type - for v in &self.0 { + for v in &self.values { if v.get_type() != *ty { return Err(CustomCheckFailure::Message(format!( "Array element {v:?} is not of expected type {ty}" @@ -126,7 +139,7 @@ impl CustomConst for ArrayValue { } fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::union_over(self.0.iter().map(Value::extension_reqs)) + ExtensionSet::union_over(self.values.iter().map(Value::extension_reqs)) .union(EXTENSION_ID.into()) } @@ -134,10 +147,10 @@ impl CustomConst for ArrayValue { &mut self, extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { - for val in &mut self.0 { + for val in &mut self.values { resolve_value_extensions(val, extensions)?; } - resolve_type_extensions(&mut self.1, extensions) + resolve_type_extensions(&mut self.typ, extensions) } } @@ -239,11 +252,17 @@ mod test { #[test] fn test_array_value() { - let array_value = ArrayValue(vec![ConstUsize::new(3).into()], usize_t()); + let array_value = ArrayValue { + values: vec![ConstUsize::new(3).into()], + typ: usize_t(), + }; array_value.validate().unwrap(); - let wrong_array_value = ArrayValue(vec![ConstF64::new(1.2).into()], usize_t()); + let wrong_array_value = ArrayValue { + values: vec![ConstF64::new(1.2).into()], + typ: usize_t(), + }; assert!(wrong_array_value.validate().is_err()); } }