From d429cffc8a5a6a10af44b701aca772622c862eb6 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 17 Dec 2024 15:34:57 +0000 Subject: [PATCH] feat: add ArrayValue to python, rust and lowering (#1773) drive-by: fix Array type annotation Closes #1771 TODO: - [x] Rust ArrayValue - [x] LLVM lowering --------- Co-authored-by: Mark Koch --- .../src/std_extensions/collections/array.rs | 177 +++++++++++++++++- .../src/std_extensions/collections/list.rs | 2 +- hugr-llvm/src/extension/collections/array.rs | 56 ++++++ ..._array__test__emit_array_value@llvm14.snap | 14 ++ ...__emit_array_value@pre-mem2reg@llvm14.snap | 20 ++ hugr-py/src/hugr/std/collections/array.py | 27 ++- hugr-py/tests/test_tys.py | 6 +- 7 files changed, 288 insertions(+), 14 deletions(-) create mode 100644 hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@llvm14.snap create mode 100644 hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@pre-mem2reg@llvm14.snap diff --git a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index b38e762b5..c1ccfe57e 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -6,13 +6,21 @@ mod array_scan; use std::sync::Arc; +use itertools::Itertools as _; use lazy_static::lazy_static; +use serde::{Deserialize, Serialize}; +use std::hash::{Hash, Hasher}; +use crate::extension::resolution::{ + resolve_type_extensions, resolve_value_extensions, ExtensionResolutionError, + WeakExtensionRegistry, +}; use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; -use crate::extension::{ExtensionId, SignatureError, TypeDef, TypeDefBound}; -use crate::ops::{ExtensionOp, OpName}; +use crate::extension::{ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound}; +use crate::ops::constant::{maybe_hash_values, CustomConst, TryHash, ValueName}; +use crate::ops::{ExtensionOp, OpName, Value}; use crate::types::type_param::{TypeArg, TypeParam}; -use crate::types::{Type, TypeBound, TypeName}; +use crate::types::{CustomCheckFailure, CustomType, Type, TypeBound, TypeName}; use crate::Extension; pub use array_op::{ArrayOp, ArrayOpDef, ArrayOpDefIter}; @@ -26,8 +34,128 @@ pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections.ar /// Extension version. pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +/// Statically sized array of values, all of the same type. +pub struct ArrayValue { + values: Vec, + typ: Type, +} + +impl ArrayValue { + /// Create a new [CustomConst] for an array of values of type `typ`. + /// That all values are of type `typ` is not checked here. + pub fn new(typ: Type, contents: impl IntoIterator) -> Self { + Self { + values: contents.into_iter().collect_vec(), + typ, + } + } + + /// Create a new [CustomConst] for an empty array of values of type `typ`. + pub fn new_empty(typ: Type) -> Self { + Self { + values: vec![], + typ, + } + } + + /// Returns the type of the `[ArrayValue]` as a `[CustomType]`.` + pub fn custom_type(&self) -> CustomType { + array_custom_type(self.values.len() as u64, self.typ.clone()) + } + + /// Returns the type of values inside the `[ArrayValue]`. + pub fn get_element_type(&self) -> &Type { + &self.typ + } + + /// Returns the values contained inside the `[ArrayValue]`. + pub fn get_contents(&self) -> &[Value] { + &self.values + } +} + +impl TryHash for ArrayValue { + fn try_hash(&self, mut st: &mut dyn Hasher) -> bool { + maybe_hash_values(&self.values, &mut st) && { + self.typ.hash(&mut st); + true + } + } +} + +#[typetag::serde] +impl CustomConst for ArrayValue { + fn name(&self) -> ValueName { + ValueName::new_inline("array") + } + + fn get_type(&self) -> Type { + self.custom_type().into() + } + + fn validate(&self) -> Result<(), CustomCheckFailure> { + let typ = self.custom_type(); + + EXTENSION + .get_type(&ARRAY_TYPENAME) + .unwrap() + .check_custom(&typ) + .map_err(|_| { + CustomCheckFailure::Message(format!( + "Custom typ {typ} is not a valid instantiation of array." + )) + })?; + + // constant can only hold classic type. + let ty = match typ.args() { + [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] + if *n as usize == self.values.len() => + { + ty + } + _ => { + return Err(CustomCheckFailure::Message(format!( + "Invalid array type arguments: {:?}", + typ.args() + ))) + } + }; + + // check all values are instances of the element type + for v in &self.values { + if v.get_type() != *ty { + return Err(CustomCheckFailure::Message(format!( + "Array element {v:?} is not of expected type {ty}" + ))); + } + } + + Ok(()) + } + + fn equal_consts(&self, other: &dyn CustomConst) -> bool { + crate::ops::constant::downcast_equal_consts(self, other) + } + + fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::union_over(self.values.iter().map(Value::extension_reqs)) + .union(EXTENSION_ID.into()) + } + + fn update_extensions( + &mut self, + extensions: &WeakExtensionRegistry, + ) -> Result<(), ExtensionResolutionError> { + for val in &mut self.values { + resolve_value_extensions(val, extensions)?; + } + resolve_type_extensions(&mut self.typ, extensions) + } +} + lazy_static! { - /// Extension for list operations. + /// Extension for array operations. pub static ref EXTENSION: Arc = { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { extension.add_type( @@ -55,7 +183,7 @@ fn array_type_def() -> &'static TypeDef { /// This method is equivalent to [`array_type_parametric`], but uses concrete /// arguments types to ensure no errors are possible. pub fn array_type(size: u64, element_ty: Type) -> Type { - instantiate_array(array_type_def(), size, element_ty).expect("array parameters are valid") + array_custom_type(size, element_ty).into() } /// Instantiate a new array type given the size and element type parameters. @@ -68,14 +196,25 @@ pub fn array_type_parametric( instantiate_array(array_type_def(), size, element_ty) } +fn array_custom_type(size: impl Into, element_ty: impl Into) -> CustomType { + instantiate_array_custom(array_type_def(), size, element_ty) + .expect("array parameters are valid") +} + +fn instantiate_array_custom( + array_def: &TypeDef, + size: impl Into, + element_ty: impl Into, +) -> Result { + array_def.instantiate(vec![size.into(), element_ty.into()]) +} + fn instantiate_array( array_def: &TypeDef, size: impl Into, element_ty: impl Into, ) -> Result { - array_def - .instantiate(vec![size.into(), element_ty.into()]) - .map(Into::into) + instantiate_array_custom(array_def, size, element_ty).map(Into::into) } /// Name of the operation in the prelude for creating new arrays. @@ -90,9 +229,11 @@ pub fn new_array_op(element_ty: Type, size: u64) -> ExtensionOp { #[cfg(test)] mod test { use crate::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr}; - use crate::extension::prelude::qb_t; + use crate::extension::prelude::{qb_t, usize_t, ConstUsize}; + use crate::ops::constant::CustomConst; + use crate::std_extensions::arithmetic::float_types::ConstF64; - use super::{array_type, new_array_op}; + use super::{array_type, new_array_op, ArrayValue}; #[test] /// Test building a HUGR involving a new_array operation. @@ -108,4 +249,20 @@ mod test { b.finish_hugr_with_outputs(out.outputs()).unwrap(); } + + #[test] + fn test_array_value() { + let array_value = ArrayValue { + values: vec![ConstUsize::new(3).into()], + typ: usize_t(), + }; + + array_value.validate().unwrap(); + + let wrong_array_value = ArrayValue { + values: vec![ConstF64::new(1.2).into()], + typ: usize_t(), + }; + assert!(wrong_array_value.validate().is_err()); + } } diff --git a/hugr-core/src/std_extensions/collections/list.rs b/hugr-core/src/std_extensions/collections/list.rs index 79330a68b..53e82dd51 100644 --- a/hugr-core/src/std_extensions/collections/list.rs +++ b/hugr-core/src/std_extensions/collections/list.rs @@ -49,7 +49,7 @@ pub struct ListValue(Vec, Type); impl ListValue { /// Create a new [CustomConst] for a list of values of type `typ`. - /// That all values ore of type `typ` is not checked here. + /// That all values are of type `typ` is not checked here. pub fn new(typ: Type, contents: impl IntoIterator) -> Self { Self(contents.into_iter().collect_vec(), typ) } diff --git a/hugr-llvm/src/extension/collections/array.rs b/hugr-llvm/src/extension/collections/array.rs index 39d721a29..0cf530e9f 100644 --- a/hugr-llvm/src/extension/collections/array.rs +++ b/hugr-llvm/src/extension/collections/array.rs @@ -18,6 +18,7 @@ use inkwell::values::{ use inkwell::IntPredicate; use itertools::Itertools; +use crate::emit::emit_value; use crate::{ emit::{deaggregate_call_result, EmitFuncContext, RowPromise}, sum::LLVMSumType, @@ -52,6 +53,15 @@ pub trait ArrayCodegen: Clone { elem_ty.array_type(size as u32) } + /// Emit a [hugr_core::std_extensions::collections::array::ArrayValue]. + fn emit_array_value<'c, H: HugrView>( + &self, + ctx: &mut EmitFuncContext<'c, '_, H>, + value: &array::ArrayValue, + ) -> Result> { + emit_array_value(self, ctx, value) + } + /// Emit a [hugr_core::std_extensions::collections::array::ArrayOp]. fn emit_array_op<'c, H: HugrView>( &self, @@ -129,6 +139,10 @@ impl CodegenExtension for ArrayCodegenExtension { Ok(ccg.array_type(&ts, elem_ty, *n).as_basic_type_enum()) } }) + .custom_const::({ + let ccg = self.0.clone(); + move |context, k| ccg.emit_array_value(context, k) + }) .simple_extension_op::({ let ccg = self.0.clone(); move |context, args, _| { @@ -244,6 +258,31 @@ fn build_loop<'c, T, H: HugrView>( Ok(val) } +pub fn emit_array_value<'c, H: HugrView>( + ccg: &impl ArrayCodegen, + ctx: &mut EmitFuncContext<'c, '_, H>, + value: &array::ArrayValue, +) -> Result> { + let ts = ctx.typing_session(); + let llvm_array_ty = ccg + .array_type( + &ts, + ts.llvm_type(value.get_element_type())?, + value.get_contents().len() as u64, + ) + .as_basic_type_enum() + .into_array_type(); + let mut array_v = llvm_array_ty.get_undef(); + for (i, v) in value.get_contents().iter().enumerate() { + let llvm_v = emit_value(ctx, v)?; + array_v = ctx + .builder() + .build_insert_value(array_v, llvm_v, i as u32, "")? + .into_array_value(); + } + Ok(array_v.into()) +} + pub fn emit_array_op<'c, H: HugrView>( ccg: &impl ArrayCodegen, ctx: &mut EmitFuncContext<'c, '_, H>, @@ -739,6 +778,23 @@ mod test { check_emission!(hugr, llvm_ctx); } + #[rstest] + fn emit_array_value(mut llvm_ctx: TestContext) { + let hugr = SimpleHugrConfig::new() + .with_extensions(STD_REG.to_owned()) + .with_outs(vec![array_type(2, usize_t())]) + .finish(|mut builder| { + let vs = vec![ConstUsize::new(1).into(), ConstUsize::new(2).into()]; + let arr = builder.add_load_value(array::ArrayValue::new(usize_t(), vs)); + builder.finish_with_outputs([arr]).unwrap() + }); + llvm_ctx.add_extensions(|cge| { + cge.add_default_prelude_extensions() + .add_default_array_extensions() + }); + check_emission!(hugr, llvm_ctx); + } + fn exec_registry() -> ExtensionRegistry { ExtensionRegistry::new([ int_types::EXTENSION.to_owned(), diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@llvm14.snap new file mode 100644 index 000000000..3a718f7f2 --- /dev/null +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@llvm14.snap @@ -0,0 +1,14 @@ +--- +source: hugr-llvm/src/extension/collections/array.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define [2 x i64] @_hl.main.1() { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + ret [2 x i64] [i64 1, i64 2] +} diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@pre-mem2reg@llvm14.snap new file mode 100644 index 000000000..5befaf3df --- /dev/null +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@pre-mem2reg@llvm14.snap @@ -0,0 +1,20 @@ +--- +source: hugr-llvm/src/extension/collections/array.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define [2 x i64] @_hl.main.1() { +alloca_block: + %"0" = alloca [2 x i64], align 8 + %"5_0" = alloca [2 x i64], align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store [2 x i64] [i64 1, i64 2], [2 x i64]* %"5_0", align 4 + %"5_01" = load [2 x i64], [2 x i64]* %"5_0", align 4 + store [2 x i64] %"5_01", [2 x i64]* %"0", align 4 + %"02" = load [2 x i64], [2 x i64]* %"0", align 4 + ret [2 x i64] %"02" +} diff --git a/hugr-py/src/hugr/std/collections/array.py b/hugr-py/src/hugr/std/collections/array.py index d59b00af3..f7638e4f7 100644 --- a/hugr-py/src/hugr/std/collections/array.py +++ b/hugr-py/src/hugr/std/collections/array.py @@ -4,8 +4,9 @@ from dataclasses import dataclass -import hugr.tys as tys +from hugr import tys, val from hugr.std import _load_extension +from hugr.utils import comma_sep_str EXTENSION = _load_extension("collections.array") @@ -14,7 +15,7 @@ class Array(tys.ExtType): """Fixed `size` array of `ty` elements.""" - def __init__(self, ty: tys.Type, size: int | tys.BoundedNatArg) -> None: + def __init__(self, ty: tys.Type, size: int | tys.TypeArg) -> None: if isinstance(size, int): size = tys.BoundedNatArg(size) @@ -52,3 +53,25 @@ def size(self) -> int | None: def type_bound(self) -> tys.TypeBound: return self.ty.type_bound() + + +@dataclass +class ArrayVal(val.ExtensionValue): + """Constant value for a statically sized array of elements.""" + + v: list[val.Value] + ty: tys.Type + + def __init__(self, v: list[val.Value], elem_ty: tys.Type) -> None: + self.v = v + self.ty = Array(elem_ty, len(v)) + + def to_value(self) -> val.Extension: + name = "ArrayValue" + # The value list must be serialized at this point, otherwise the + # `Extension` value would not be serializable. + vs = [v._to_serial_root() for v in self.v] + return val.Extension(name, typ=self.ty, val=vs, extensions=[EXTENSION.name]) + + def __str__(self) -> str: + return f"array({comma_sep_str(self.v)})" diff --git a/hugr-py/tests/test_tys.py b/hugr-py/tests/test_tys.py index bca301182..e12e6fb0b 100644 --- a/hugr-py/tests/test_tys.py +++ b/hugr-py/tests/test_tys.py @@ -3,7 +3,7 @@ import pytest from hugr import val -from hugr.std.collections.array import Array +from hugr.std.collections.array import Array, ArrayVal from hugr.std.collections.list import List, ListVal from hugr.std.float import FLOAT_T from hugr.std.int import INT_T, _int_tv @@ -170,3 +170,7 @@ def test_array(): ls = Array(ty_var, len_var) assert ls.ty == ty_var assert ls.size is None + + ar_val = ArrayVal([val.TRUE, val.FALSE], Bool) + assert ar_val.v == [val.TRUE, val.FALSE] + assert ar_val.ty == Array(Bool, 2)