Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Expand test of instantiating extension sets #1003

Merged
merged 4 commits into from
May 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 109 additions & 22 deletions hugr/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@ use cool_asserts::assert_matches;
use super::*;
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{
BuildError, Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder,
HugrBuilder, ModuleBuilder,
BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer,
};
use crate::extension::prelude::{BOOL_T, PRELUDE, USIZE_T};
use crate::extension::{Extension, ExtensionId, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY};
use crate::extension::prelude::{BOOL_T, PRELUDE, PRELUDE_ID, USIZE_T};
use crate::extension::{Extension, ExtensionSet, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::HugrMut;
use crate::ops::dataflow::IOTrait;
use crate::ops::{self, Noop, Value};
use crate::ops::handle::NodeHandle;
use crate::ops::leaf::MakeTuple;
use crate::ops::{self, Noop, OpType, Value};
use crate::std_extensions::logic::test::{and_op, or_op};
use crate::std_extensions::logic::{self, NotOp};
use crate::types::type_param::{TypeArg, TypeArgError};
use crate::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound, TypeRow};
use crate::{type_row, IncomingPort};
use crate::{const_extension_ids, type_row, Direction, IncomingPort, Node};

const NAT: Type = crate::extension::prelude::USIZE_T;

Expand Down Expand Up @@ -336,10 +338,12 @@ fn unregistered_extension() {
h.update_validate(&PRELUDE_REGISTRY).unwrap();
}

const_extension_ids! {
const EXT_ID: ExtensionId = "MyExt";
}
#[test]
fn invalid_types() {
let name: ExtensionId = "MyExt".try_into().unwrap();
let mut e = Extension::new(name.clone());
let mut e = Extension::new(EXT_ID);
e.add_type(
"MyContainer".into(),
vec![TypeBound::Copyable.into()],
Expand All @@ -360,7 +364,7 @@ fn invalid_types() {
let valid = Type::new_extension(CustomType::new(
"MyContainer",
vec![TypeArg::Type { ty: USIZE_T }],
name.clone(),
EXT_ID,
TypeBound::Any,
));
assert_eq!(
Expand All @@ -374,7 +378,7 @@ fn invalid_types() {
let element_outside_bound = CustomType::new(
"MyContainer",
vec![TypeArg::Type { ty: valid.clone() }],
name.clone(),
EXT_ID,
TypeBound::Any,
);
assert_eq!(
Expand All @@ -388,7 +392,7 @@ fn invalid_types() {
let bad_bound = CustomType::new(
"MyContainer",
vec![TypeArg::Type { ty: USIZE_T }],
name.clone(),
EXT_ID,
TypeBound::Copyable,
);
assert_eq!(
Expand All @@ -405,7 +409,7 @@ fn invalid_types() {
vec![TypeArg::Type {
ty: Type::new_extension(bad_bound),
}],
name.clone(),
EXT_ID,
TypeBound::Any,
);
assert_eq!(
Expand All @@ -419,7 +423,7 @@ fn invalid_types() {
let too_many_type_args = CustomType::new(
"MyContainer",
vec![TypeArg::Type { ty: USIZE_T }, TypeArg::BoundedNat { n: 3 }],
name.clone(),
EXT_ID,
TypeBound::Any,
);
assert_eq!(
Expand Down Expand Up @@ -544,18 +548,101 @@ fn no_polymorphic_consts() -> Result<(), Box<dyn std::error::Error>> {

#[test]
fn test_polymorphic_call() -> Result<(), Box<dyn std::error::Error>> {
let mut m = ModuleBuilder::new();
let id = m.declare(
"id",
let mut e = Extension::new(EXT_ID);

let params: Vec<TypeParam> = vec![
TypeBound::Any.into(),
TypeParam::Extensions,
TypeBound::Any.into(),
];
let evaled_fn = Type::new_function(
FunctionType::new(
Type::new_var_use(0, TypeBound::Any),
Type::new_var_use(2, TypeBound::Any),
)
.with_extension_delta(ExtensionSet::type_var(1)),
);
// The higher-order "eval" operation - takes a function and its argument.
// Note the extension-delta of the eval node includes that of the input function.
e.add_op(
"eval".into(),
"".into(),
PolyFuncType::new(
vec![TypeBound::Any.into()],
FunctionType::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]),
params.clone(),
FunctionType::new(
vec![evaled_fn, Type::new_var_use(0, TypeBound::Any)],
Type::new_var_use(2, TypeBound::Any),
)
.with_extension_delta(ExtensionSet::type_var(1)),
),
)?;
let mut f = m.define_function("main", FunctionType::new_endo(vec![USIZE_T]).into())?;
let c = f.call(&id, &[USIZE_T.into()], f.input_wires(), &PRELUDE_REGISTRY)?;
f.finish_with_outputs(c.outputs())?;
let _ = m.finish_prelude_hugr()?;

fn utou(e: impl Into<ExtensionSet>) -> Type {
Type::new_function(FunctionType::new_endo(USIZE_T).with_extension_delta(e.into()))
}

let int_pair = Type::new_tuple(type_row![USIZE_T; 2]);
// Root DFG: applies a function int--PRELUDE-->int to each element of a pair of two ints
let mut d = DFGBuilder::new(
FunctionType::new(
vec![utou(PRELUDE_ID), int_pair.clone()],
vec![int_pair.clone()],
)
.with_extension_delta(PRELUDE_ID),
)?;
// ....by calling a function parametrized<extensions E> (int--e-->int, int_pair) -> int_pair
let f = {
let es = ExtensionSet::type_var(0);
let mut f = d.define_function(
"two_ints",
PolyFuncType::new(
vec![TypeParam::Extensions],
FunctionType::new(vec![utou(es.clone()), int_pair.clone()], int_pair.clone())
.with_extension_delta(es.clone()),
),
)?;
let [func, tup] = f.input_wires_arr();
let mut c = f.conditional_builder(
(vec![type_row![USIZE_T; 2]], tup),
vec![],
type_row![USIZE_T;2],
es.clone(),
)?;
let mut cc = c.case_builder(0)?;
let [i1, i2] = cc.input_wires_arr();
let op = e.instantiate_extension_op(
"eval",
vec![USIZE_T.into(), TypeArg::Extensions { es }, USIZE_T.into()],
&PRELUDE_REGISTRY,
)?;
let [f1] = cc.add_dataflow_op(op.clone(), [func, i1])?.outputs_arr();
let [f2] = cc.add_dataflow_op(op, [func, i2])?.outputs_arr();
cc.finish_with_outputs([f1, f2])?;
let res = c.finish_sub_container()?.outputs();
let tup = f.add_dataflow_op(
MakeTuple {
tys: type_row![USIZE_T; 2],
},
res,
)?;
f.finish_with_outputs(tup.outputs())?
};

let reg = ExtensionRegistry::try_new([e, PRELUDE.to_owned()])?;
let [func, tup] = d.input_wires_arr();
let call = d.call(
f.handle(),
&[TypeArg::Extensions {
es: ExtensionSet::singleton(&PRELUDE_ID),
}],
[func, tup],
&reg,
)?;
let h = d.finish_hugr_with_outputs(call.outputs(), &reg)?;
let call_ty = h.get_optype(call.node()).dataflow_signature().unwrap();
let exp_fun_ty = FunctionType::new(vec![utou(PRELUDE_ID), int_pair.clone()], int_pair)
.with_extension_delta(PRELUDE_ID);
assert_eq!(call_ty, exp_fun_ty);
Ok(())
}

Expand Down