Skip to content

Commit

Permalink
feat: fill out array ops (#1491)
Browse files Browse the repository at this point in the history
Closes #1320

---------

Co-authored-by: Agustín Borgna <[email protected]>
  • Loading branch information
ss2165 and aborgna-q authored Aug 30, 2024
1 parent b2a30a0 commit 26ec57a
Show file tree
Hide file tree
Showing 4 changed files with 1,188 additions and 59 deletions.
69 changes: 10 additions & 59 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ use crate::extension::{
ConstFold, ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDefBound,
};
use crate::ops::constant::{CustomCheckFailure, CustomConst, ValueName};
use crate::ops::{ExtensionOp, NamedOp, OpName, Value};
use crate::ops::OpName;
use crate::ops::{NamedOp, Value};
use crate::types::type_param::{TypeArg, TypeParam};
use crate::types::{
CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeBound,
Expand All @@ -22,30 +23,11 @@ use crate::{type_row, Extension};

use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use super::{ExtensionRegistry, SignatureFromArgs};
struct ArrayOpCustom;
use super::ExtensionRegistry;

const MAX: &[TypeParam; 1] = &[TypeParam::max_nat()];
impl SignatureFromArgs for ArrayOpCustom {
fn compute_signature(&self, arg_values: &[TypeArg]) -> Result<PolyFuncTypeRV, SignatureError> {
let [TypeArg::BoundedNat { n }] = *arg_values else {
return Err(SignatureError::InvalidTypeArgs);
};
let elem_ty_var = Type::new_var_use(0, TypeBound::Any);

let var_arg_row = vec![elem_ty_var.clone(); n as usize];
let other_row = vec![array_type(TypeArg::BoundedNat { n }, elem_ty_var.clone())];

Ok(PolyFuncTypeRV::new(
vec![TypeBound::Any.into()],
FuncValueType::new(var_arg_row, other_row),
))
}

fn static_params(&self) -> &[TypeParam] {
MAX
}
}
/// Array type and operations.
pub mod array;
pub use array::{array_type, new_array_op, ArrayOp, ArrayOpDef, ARRAY_TYPE_NAME, NEW_ARRAY_OP_ID};

/// Name of prelude extension.
pub const PRELUDE_ID: ExtensionId = ExtensionId::new_unchecked("prelude");
Expand Down Expand Up @@ -76,19 +58,12 @@ lazy_static! {
)
.unwrap();
prelude.add_type(
TypeName::new_inline("array"),
TypeName::new_inline(ARRAY_TYPE_NAME),
vec![ TypeParam::max_nat(), TypeBound::Any.into()],
"array".into(),
TypeDefBound::from_params(vec![1] ),
)
.unwrap();
prelude
.add_op(
NEW_ARRAY_OP_ID,
"Create a new array from elements".to_string(),
ArrayOpCustom,
)
.unwrap();

prelude
.add_type(
Expand Down Expand Up @@ -125,6 +100,7 @@ lazy_static! {
TupleOpDef::load_all_ops(&mut prelude).unwrap();
NoopDef.add_to_extension(&mut prelude).unwrap();
LiftDef.add_to_extension(&mut prelude).unwrap();
array::ArrayOpDef::load_all_ops(&mut prelude).unwrap();
prelude
};
/// An extension registry containing only the prelude
Expand Down Expand Up @@ -152,18 +128,6 @@ pub const USIZE_T: Type = Type::new_extension(USIZE_CUSTOM_T);
/// Boolean type - Sum of two units.
pub const BOOL_T: Type = Type::new_unit_sum(2);

/// Initialize a new array of element type `element_ty` of length `size`
pub fn array_type(size: impl Into<TypeArg>, element_ty: Type) -> Type {
let array_def = PRELUDE.get_type("array").unwrap();
let custom_t = array_def
.instantiate(vec![size.into(), element_ty.into()])
.unwrap();
Type::new_extension(custom_t)
}

/// Name of the operation in the prelude for creating new arrays.
pub const NEW_ARRAY_OP_ID: OpName = OpName::new_inline("new_array");

/// Name of the prelude panic operation.
///
/// This operation can have any input and any output wires; it is instantiated
Expand All @@ -175,20 +139,6 @@ pub const NEW_ARRAY_OP_ID: OpName = OpName::new_inline("new_array");
/// satisfied.
pub const PANIC_OP_ID: OpName = OpName::new_inline("panic");

/// Initialize a new array op of element type `element_ty` of length `size`
pub fn new_array_op(element_ty: Type, size: u64) -> ExtensionOp {
PRELUDE
.instantiate_extension_op(
&NEW_ARRAY_OP_ID,
vec![
TypeArg::BoundedNat { n: size },
TypeArg::Type { ty: element_ty },
],
&PRELUDE_REGISTRY,
)
.unwrap()
}

/// Name of the string type.
pub const STRING_TYPE_NAME: TypeName = TypeName::new_inline("string");

Expand Down Expand Up @@ -934,10 +884,11 @@ impl MakeRegisteredOp for Lift {

#[cfg(test)]
mod test {
use crate::builder::inout_sig;
use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::{
builder::{endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr},
builder::{endo_sig, DFGBuilder, Dataflow, DataflowHugr},
utils::test_quantum_extension::cx_gate,
Hugr, Wire,
};
Expand Down
Loading

0 comments on commit 26ec57a

Please sign in to comment.