From 6be3ca2612db40ebe005d58a089d86b3832aacb4 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 5 Mar 2024 14:00:03 +0000 Subject: [PATCH] refactor: Add `From` conversion from ExtensionId to ExtensionSet (#855) This case comes up too often to keep writing out `ExtensionSet::singleton` --- src/builder/build_traits.rs | 17 +++--- src/builder/conditional.rs | 4 +- src/builder/dataflow.rs | 13 +++-- src/builder/tail_loop.rs | 2 +- src/extension.rs | 12 +++-- src/extension/infer/test.rs | 86 ++++++++++++------------------- src/hugr.rs | 2 +- src/hugr/rewrite/inline_dfg.rs | 15 +++--- src/hugr/validate/test.rs | 26 ++++------ src/hugr/views/tests.rs | 8 +-- src/ops/constant.rs | 11 ++-- src/ops/controlflow.rs | 3 +- src/ops/dataflow.rs | 1 - src/ops/leaf.rs | 6 +-- src/std_extensions/collections.rs | 2 +- src/types/signature.rs | 4 +- src/values.rs | 5 +- 17 files changed, 95 insertions(+), 122 deletions(-) diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 41af807c0..7b1b4621e 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -71,7 +71,7 @@ pub trait Container { /// This function will return an error if there is an error in adding the /// [`OpType::Const`] node. fn add_constant(&mut self, constant: impl Into) -> Result { - let const_n = self.add_child_node(NodeType::new(constant.into(), ExtensionSet::new()))?; + let const_n = self.add_child_node(NodeType::new_pure(constant.into()))?; Ok(const_n.into()) } @@ -89,13 +89,10 @@ pub trait Container { signature: PolyFuncType, ) -> Result, BuildError> { let body = signature.body().clone(); - let f_node = self.add_child_node(NodeType::new( - ops::FuncDefn { - name: name.into(), - signature, - }, - ExtensionSet::new(), - ))?; + let f_node = self.add_child_node(NodeType::new_pure(ops::FuncDefn { + name: name.into(), + signature, + }))?; let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body, Some(ExtensionSet::new()))?; @@ -335,9 +332,9 @@ pub trait Dataflow: Container { NodeType::new( ops::CFG { signature: FunctionType::new(inputs.clone(), output_types.clone()) - .with_extension_delta(&extension_delta), + .with_extension_delta(extension_delta), }, - input_extensions, + input_extensions.into(), ), input_wires, )?; diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index 1e3441968..426f80512 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -122,7 +122,7 @@ impl + AsRef> ConditionalBuilder { let outputs = cond.outputs; let case_op = ops::Case { signature: FunctionType::new(inputs.clone(), outputs.clone()) - .with_extension_delta(&extension_delta), + .with_extension_delta(extension_delta.clone()), }; let case_node = // add case before any existing subsequent cases @@ -137,7 +137,7 @@ impl + AsRef> ConditionalBuilder { let dfg_builder = DFGBuilder::create_with_io( self.hugr_mut(), case_node, - FunctionType::new(inputs, outputs).with_extension_delta(&extension_delta), + FunctionType::new(inputs, outputs).with_extension_delta(extension_delta), None, )?; diff --git a/src/builder/dataflow.rs b/src/builder/dataflow.rs index 879b43f93..6f0cf003f 100644 --- a/src/builder/dataflow.rs +++ b/src/builder/dataflow.rs @@ -428,22 +428,21 @@ pub(crate) mod test { fn lift_node() -> Result<(), BuildError> { let xa: ExtensionId = "A".try_into().unwrap(); let xb: ExtensionId = "B".try_into().unwrap(); - let xc = "C".try_into().unwrap(); + let xc: ExtensionId = "C".try_into().unwrap(); let ab_extensions = ExtensionSet::from_iter([xa.clone(), xb.clone()]); - let c_extensions = ExtensionSet::singleton(&xc); - let abc_extensions = ab_extensions.clone().union(&c_extensions); + let abc_extensions = ab_extensions.clone().union(&xc.clone().into()); let parent_sig = - FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&abc_extensions); + FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(abc_extensions); let mut parent = DFGBuilder::new(parent_sig)?; let add_c_sig = - FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&c_extensions); + FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(xc.clone()); let [w] = parent.input_wires_arr(); - let add_ab_sig = - FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&ab_extensions); + let add_ab_sig = FunctionType::new(type_row![BIT], type_row![BIT]) + .with_extension_delta(ab_extensions.clone()); // A box which adds extensions A and B, via child Lift nodes let mut add_ab = parent.dfg_builder(add_ab_sig, Some(ExtensionSet::new()), [w])?; diff --git a/src/builder/tail_loop.rs b/src/builder/tail_loop.rs index 8f99ee512..786ae9ced 100644 --- a/src/builder/tail_loop.rs +++ b/src/builder/tail_loop.rs @@ -127,7 +127,7 @@ mod test { let mut fbuild = module_builder.define_function( "main", FunctionType::new(type_row![BIT], type_row![NAT]) - .with_extension_delta(&ExtensionSet::singleton(&PRELUDE_ID)) + .with_extension_delta(PRELUDE_ID) .into(), )?; let _fdef = { diff --git a/src/extension.rs b/src/extension.rs index 6f2859e79..07b8d780e 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -288,14 +288,14 @@ pub struct Extension { impl Extension { /// Creates a new extension with the given name. pub fn new(name: ExtensionId) -> Self { - Self::new_with_reqs(name, Default::default()) + Self::new_with_reqs(name, ExtensionSet::default()) } /// Creates a new extension with the given name and requirements. - pub fn new_with_reqs(name: ExtensionId, extension_reqs: ExtensionSet) -> Self { + pub fn new_with_reqs(name: ExtensionId, extension_reqs: impl Into) -> Self { Self { name, - extension_reqs, + extension_reqs: extension_reqs.into(), types: Default::default(), values: Default::default(), operations: Default::default(), @@ -502,6 +502,12 @@ impl ExtensionSet { } } +impl From for ExtensionSet { + fn from(id: ExtensionId) -> Self { + Self::singleton(&id) + } +} + fn as_typevar(e: &ExtensionId) -> Option { // Type variables are represented as radix-10 numbers, which are illegal // as standard ExtensionIds. Hence if an ExtensionId starts with a digit, diff --git a/src/extension/infer/test.rs b/src/extension/infer/test.rs index 2a4100464..c9b36ce40 100644 --- a/src/extension/infer/test.rs +++ b/src/extension/infer/test.rs @@ -40,7 +40,7 @@ const_extension_ids! { // them. fn from_graph() -> Result<(), Box> { let rs = ExtensionSet::from_iter([A, B, C]); - let main_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]).with_extension_delta(&rs); + let main_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]).with_extension_delta(rs); let op = ops::DFG { signature: main_sig, @@ -57,17 +57,14 @@ fn from_graph() -> Result<(), Box> { assert_matches!(hugr.get_io(hugr.root()), Some(_)); - let add_a_sig = FunctionType::new(type_row![NAT], type_row![NAT]) - .with_extension_delta(&ExtensionSet::singleton(&A)); + let add_a_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(A); - let add_b_sig = FunctionType::new(type_row![NAT], type_row![NAT]) - .with_extension_delta(&ExtensionSet::singleton(&B)); + let add_b_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(B); let add_ab_sig = FunctionType::new(type_row![NAT], type_row![NAT]) - .with_extension_delta(&ExtensionSet::from_iter([A, B])); + .with_extension_delta(ExtensionSet::from_iter([A, B])); - let mult_c_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]) - .with_extension_delta(&ExtensionSet::singleton(&C)); + let mult_c_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]).with_extension_delta(C); let add_a = hugr.add_node_with_parent( hugr.root(), @@ -128,16 +125,10 @@ fn plus() -> Result<(), InferExtensionError> { }) .collect(); - ctx.solved.insert(metas[2], ExtensionSet::singleton(&A)); + ctx.solved.insert(metas[2], A.into()); ctx.add_constraint(metas[1], Constraint::Equal(metas[2])); - ctx.add_constraint( - metas[0], - Constraint::Plus(ExtensionSet::singleton(&B), metas[2]), - ); - ctx.add_constraint( - metas[4], - Constraint::Plus(ExtensionSet::singleton(&C), metas[0]), - ); + ctx.add_constraint(metas[0], Constraint::Plus(B.into(), metas[2])); + ctx.add_constraint(metas[4], Constraint::Plus(C.into(), metas[0])); ctx.add_constraint(metas[3], Constraint::Equal(metas[4])); ctx.add_constraint(metas[5], Constraint::Equal(metas[0])); ctx.main_loop()?; @@ -164,8 +155,7 @@ fn plus() -> Result<(), InferExtensionError> { // because of a missing lift node fn missing_lift_node() -> Result<(), Box> { let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG { - signature: FunctionType::new(type_row![NAT], type_row![NAT]) - .with_extension_delta(&ExtensionSet::singleton(&A)), + signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(A), })); let input = hugr.add_node_with_parent( @@ -211,8 +201,8 @@ fn open_variables() -> Result<(), InferExtensionError> { .insert((NodeIndex::new(4).into(), Direction::Incoming), ab); ctx.variables.insert(a); ctx.variables.insert(b); - ctx.add_constraint(ab, Constraint::Plus(ExtensionSet::singleton(&A), b)); - ctx.add_constraint(ab, Constraint::Plus(ExtensionSet::singleton(&B), a)); + ctx.add_constraint(ab, Constraint::Plus(A.into(), b)); + ctx.add_constraint(ab, Constraint::Plus(B.into(), a)); let solution = ctx.main_loop()?; // We'll only find concrete solutions for the Incoming extension reqs of // the main node created by `Hugr::default` @@ -227,11 +217,12 @@ fn dangling_src() -> Result<(), Box> { let rs = ExtensionSet::singleton(&"R".try_into().unwrap()); let mut hugr = closed_dfg_root_hugr( - FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs), + FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(rs.clone()), ); let [input, output] = hugr.get_io(hugr.root()).unwrap(); - let add_r_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs); + let add_r_sig = + FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(rs.clone()); let add_r = hugr.add_node_with_parent( hugr.root(), @@ -241,8 +232,7 @@ fn dangling_src() -> Result<(), Box> { )?; // Dangling thingy - let src_sig = - FunctionType::new(type_row![], type_row![NAT]).with_extension_delta(&ExtensionSet::new()); + let src_sig = FunctionType::new(type_row![], type_row![NAT]); let src = hugr.add_node_with_parent(hugr.root(), ops::DFG { signature: src_sig })?; @@ -365,7 +355,7 @@ fn test_conditional_inference() -> Result<(), Box> { let conditional_node = hugr.root(); let case_op = ops::Case { - signature: FunctionType::new(inputs, outputs).with_extension_delta(&rs), + signature: FunctionType::new(inputs, outputs).with_extension_delta(rs), }; let case0_node = build_case(&mut hugr, conditional_node, case_op.clone(), A, B)?; @@ -393,7 +383,7 @@ fn extension_adding_sequence() -> Result<(), Box> { let mut hugr = Hugr::new(NodeType::new_open(ops::DFG { signature: df_sig .clone() - .with_extension_delta(&ExtensionSet::from_iter([A, B])), + .with_extension_delta(ExtensionSet::from_iter([A, B])), })); let root = hugr.root(); @@ -414,9 +404,7 @@ fn extension_adding_sequence() -> Result<(), Box> { let df_nodes: Vec = vec![A, A, B, B, A, B] .into_iter() .map(|ext| { - let dfg_sig = df_sig - .clone() - .with_extension_delta(&ExtensionSet::singleton(&ext)); + let dfg_sig = df_sig.clone().with_extension_delta(ext.clone()); let [node, input, output] = create_with_io( &mut hugr, root, @@ -468,7 +456,7 @@ fn make_block( let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect(); let tuple_sum_type = Type::new_tuple_sum(tuple_sum_rows.clone()); let dfb_sig = FunctionType::new(inputs.clone(), vec![tuple_sum_type]) - .with_extension_delta(&extension_delta.clone()); + .with_extension_delta(extension_delta.clone()); let dfb = ops::DataflowBlock { inputs, other_outputs: type_row![], @@ -554,14 +542,11 @@ fn create_entry_exit( /// +-------------------------+ #[test] fn infer_cfg_test() -> Result<(), Box> { - let a = ExtensionSet::singleton(&A); let abc = ExtensionSet::from_iter([A, B, C]); let bc = ExtensionSet::from_iter([B, C]); - let b = ExtensionSet::singleton(&B); - let c = ExtensionSet::singleton(&C); let mut hugr = Hugr::new(NodeType::new_open(ops::CFG { - signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&abc), + signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(abc), })); let root = hugr.root(); @@ -571,7 +556,7 @@ fn infer_cfg_test() -> Result<(), Box> { root, type_row![NAT], vec![type_row![NAT], type_row![NAT]], - a.clone(), + A.into(), type_row![NAT], )?; @@ -579,7 +564,7 @@ fn infer_cfg_test() -> Result<(), Box> { entry, make_opaque( A, - FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&a), + FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(A), ), )?; @@ -600,7 +585,7 @@ fn infer_cfg_test() -> Result<(), Box> { root, type_row![NAT], vec![type_row![NAT], type_row![NAT]], - b.clone(), + B.into(), )?; let bb10 = make_block( @@ -608,7 +593,7 @@ fn infer_cfg_test() -> Result<(), Box> { root, type_row![NAT], vec![type_row![NAT]], - c.clone(), + C.into(), )?; let bb11 = make_block( @@ -616,7 +601,7 @@ fn infer_cfg_test() -> Result<(), Box> { root, type_row![NAT], vec![type_row![NAT]], - c.clone(), + C.into(), )?; // CFG Wiring @@ -743,7 +728,7 @@ fn make_looping_cfg( let mut hugr = Hugr::new(NodeType::new_open(ops::CFG { signature: FunctionType::new(type_row![NAT], type_row![NAT]) - .with_extension_delta(&hugr_delta), + .with_extension_delta(hugr_delta), })); let root = hugr.root(); @@ -761,7 +746,7 @@ fn make_looping_cfg( entry, make_opaque( UNKNOWN_EXTENSION, - FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&entry_ext), + FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(entry_ext), ), )?; @@ -818,10 +803,9 @@ fn simple_cfg_loop() -> Result<(), Box> { let mut hugr = Hugr::new(NodeType::new( ops::CFG { - signature: FunctionType::new(type_row![NAT], type_row![NAT]) - .with_extension_delta(&just_a), + signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(A), }, - just_a.clone(), + Some(A.into()), )); let root = hugr.root(); @@ -865,8 +849,7 @@ fn simple_cfg_loop() -> Result<(), Box> { #[test] fn plus_on_self() -> Result<(), Box> { let ext = ExtensionId::new("unknown1").unwrap(); - let delta = ExtensionSet::singleton(&ext); - let ft = FunctionType::new_endo(type_row![QB_T, QB_T]).with_extension_delta(&delta); + let ft = FunctionType::new_endo(type_row![QB_T, QB_T]).with_extension_delta(ext.clone()); let mut dfg = DFGBuilder::new(ft.clone())?; // While https://github.com/CQCL/hugr/issues/388 is unsolved, @@ -880,8 +863,7 @@ fn plus_on_self() -> Result<(), Box> { ft, )) .into(); - let unary_sig = FunctionType::new_endo(type_row![QB_T]) - .with_extension_delta(&ExtensionSet::singleton(&ext)); + let unary_sig = FunctionType::new_endo(type_row![QB_T]).with_extension_delta(ext.clone()); let unop: LeafOp = ExternalOp::Opaque(OpaqueOp::new( ext, "1qb_op", @@ -957,7 +939,7 @@ fn simple_funcdefn() -> Result<(), Box> { let mut func_builder = builder.define_function( "F", FunctionType::new(vec![NAT], vec![NAT]) - .with_extension_delta(&ExtensionSet::singleton(&A)) + .with_extension_delta(A) .into(), )?; @@ -982,7 +964,7 @@ fn funcdefn_signature_mismatch() -> Result<(), Box> { let mut func_builder = builder.define_function( "F", FunctionType::new(vec![NAT], vec![NAT]) - .with_extension_delta(&ExtensionSet::singleton(&A)) + .with_extension_delta(A) .into(), )?; @@ -1017,7 +999,7 @@ fn funcdefn_signature_mismatch2() -> Result<(), Box> { let func_builder = builder.define_function( "F", FunctionType::new(vec![NAT], vec![NAT]) - .with_extension_delta(&ExtensionSet::singleton(&A)) + .with_extension_delta(A) .into(), )?; diff --git a/src/hugr.rs b/src/hugr.rs index 83a919abd..4a3ca1f40 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -388,7 +388,7 @@ mod test { let r = ExtensionSet::singleton(&"R".try_into().unwrap()); let mut hugr = closed_dfg_root_hugr( - FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&r), + FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(r.clone()), ); let [input, output] = hugr.get_io(hugr.root()).unwrap(); let lift = hugr.add_node_with_parent( diff --git a/src/hugr/rewrite/inline_dfg.rs b/src/hugr/rewrite/inline_dfg.rs index 9d79840a9..0d576e292 100644 --- a/src/hugr/rewrite/inline_dfg.rs +++ b/src/hugr/rewrite/inline_dfg.rs @@ -179,7 +179,7 @@ mod test { let mut outer = DFGBuilder::new( FunctionType::new(vec![int_ty.clone(); 2], vec![int_ty.clone()]) - .with_extension_delta(&delta), + .with_extension_delta(delta.clone()), )?; let [a, b] = outer.input_wires_arr(); fn make_const + AsRef>( @@ -207,7 +207,7 @@ mod test { let c1 = nonlocal.then(|| make_const(&mut outer)); let inner = { let mut inner = outer.dfg_builder( - FunctionType::new_endo(vec![int_ty.clone()]).with_extension_delta(&delta), + FunctionType::new_endo(vec![int_ty.clone()]).with_extension_delta(delta), None, [a], )?; @@ -260,9 +260,8 @@ mod test { #[test] fn permutation() -> Result<(), Box> { let mut h = DFGBuilder::new( - FunctionType::new_endo(type_row![QB_T, QB_T]).with_extension_delta( - &ExtensionSet::singleton(&test_quantum_extension::EXTENSION_ID), - ), + FunctionType::new_endo(type_row![QB_T, QB_T]) + .with_extension_delta(test_quantum_extension::EXTENSION_ID), )?; let [p, q] = h.input_wires_arr(); let [p_h] = h @@ -355,7 +354,6 @@ mod test { * \ / * CX */ - let delta = ExtensionSet::from_iter([float_types::EXTENSION_ID]); // Extension inference here relies on quantum ops not requiring their own test_quantum_extension let reg = ExtensionRegistry::try_new([ test_quantum_extension::EXTENSION.to_owned(), @@ -364,13 +362,14 @@ mod test { ]) .unwrap(); let mut outer = DFGBuilder::new( - FunctionType::new_endo(type_row![QB_T, QB_T]).with_extension_delta(&delta), + FunctionType::new_endo(type_row![QB_T, QB_T]) + .with_extension_delta(float_types::EXTENSION_ID), )?; let [a, b] = outer.input_wires_arr(); let h_a = outer.add_dataflow_op(test_quantum_extension::h_gate(), [a])?; let h_b = outer.add_dataflow_op(test_quantum_extension::h_gate(), [b])?; let mut inner = outer.dfg_builder( - FunctionType::new_endo(type_row![QB_T]).with_extension_delta(&delta), + FunctionType::new_endo(type_row![QB_T]).with_extension_delta(float_types::EXTENSION_ID), None, h_b.outputs(), )?; diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index ddad44cf2..5f8210807 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -6,9 +6,7 @@ use crate::builder::{ BuildError, Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, }; use crate::extension::prelude::{BOOL_T, PRELUDE, USIZE_T}; -use crate::extension::{ - Extension, ExtensionId, ExtensionSet, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY, -}; +use crate::extension::{Extension, ExtensionId, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::{HugrError, HugrMut, NodeType}; use crate::ops::dataflow::IOTrait; @@ -533,12 +531,12 @@ fn no_polymorphic_consts() -> Result<(), Box> { .instantiate(vec![TypeArg::new_var_use(0, BOUND)])?, ); let reg = ExtensionRegistry::try_new([collections::EXTENSION.to_owned()]).unwrap(); - let just_colns = ExtensionSet::singleton(&collections::EXTENSION_NAME); let mut def = FunctionBuilder::new( "myfunc", PolyFuncType::new( [BOUND], - FunctionType::new(vec![], vec![list_of_var.clone()]).with_extension_delta(&just_colns), + FunctionType::new(vec![], vec![list_of_var.clone()]) + .with_extension_delta(collections::EXTENSION_NAME), ), )?; let empty_list = Value::Extension { @@ -565,6 +563,7 @@ fn no_polymorphic_consts() -> Result<(), Box> { mod extension_tests { use super::*; use crate::builder::ModuleBuilder; + use crate::extension::ExtensionSet; use crate::macros::const_extension_ids; const_extension_ids! { @@ -730,7 +729,7 @@ mod extension_tests { ops::Output { types: type_row![USIZE_T], }, - ExtensionSet::singleton(&XB), + Some(XB.into()), ), ) .unwrap(); @@ -803,8 +802,7 @@ mod extension_tests { let mut main = module_builder.define_function("main", main_sig)?; let [main_input] = main.input_wires_arr(); - let inner_sig = FunctionType::new(type_row![NAT], type_row![NAT]) - .with_extension_delta(&ExtensionSet::singleton(&XA)); + let inner_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(XA); let f_builder = main.dfg_builder(inner_sig, Some(ExtensionSet::new()), [main_input])?; let f_inputs = f_builder.input_wires(); @@ -832,7 +830,7 @@ mod extension_tests { let all_rs = ExtensionSet::from_iter([XA, XB]); let main_sig = FunctionType::new(type_row![], type_row![NAT]) - .with_extension_delta(&all_rs) + .with_extension_delta(all_rs.clone()) .into(); let mut main = module_builder.define_function("main", main_sig)?; @@ -840,7 +838,7 @@ mod extension_tests { let [left_wire] = main .dfg_builder( FunctionType::new(type_row![], type_row![NAT]), - Some(ExtensionSet::singleton(&XA)), + Some(XA.into()), [], )? .finish_with_outputs([])? @@ -849,7 +847,7 @@ mod extension_tests { let [right_wire] = main .dfg_builder( FunctionType::new(type_row![], type_row![NAT]), - Some(ExtensionSet::singleton(&XB)), + Some(XB.into()), [], )? .finish_with_outputs([])? @@ -876,10 +874,8 @@ mod extension_tests { #[test] fn parent_signature_mismatch() -> Result<(), BuildError> { - let rs = ExtensionSet::singleton(&XA); - let main_signature = - FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs); + FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(XA); let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG { signature: main_signature, @@ -896,7 +892,7 @@ mod extension_tests { ops::Output { types: type_row![NAT], }, - rs, + Some(XA.into()), ), )?; hugr.connect(input, 0, output, 0)?; diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index ce0353d48..bdb478b13 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -129,14 +129,10 @@ fn value_types() { #[test] fn static_targets() { - use crate::extension::{ - prelude::{ConstUsize, PRELUDE_ID, USIZE_T}, - ExtensionSet, - }; + use crate::extension::prelude::{ConstUsize, PRELUDE_ID, USIZE_T}; use itertools::Itertools; let mut dfg = DFGBuilder::new( - FunctionType::new(type_row![], type_row![USIZE_T]) - .with_extension_delta(&ExtensionSet::singleton(&PRELUDE_ID)), + FunctionType::new(type_row![], type_row![USIZE_T]).with_extension_delta(PRELUDE_ID), ) .unwrap(); diff --git a/src/ops/constant.rs b/src/ops/constant.rs index edd569fcd..07162673d 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -142,7 +142,7 @@ mod test { builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}, extension::{ prelude::{ConstUsize, USIZE_CUSTOM_T, USIZE_T}, - ExtensionId, ExtensionRegistry, ExtensionSet, PRELUDE, + ExtensionId, ExtensionRegistry, PRELUDE, }, std_extensions::arithmetic::float_types::{self, ConstF64, FLOAT64_TYPE}, type_row, @@ -246,12 +246,9 @@ mod test { ex_id.clone(), TypeBound::Eq, ); - let val: Value = CustomSerialized::new( - typ_int.clone(), - YamlValue::Number(6.into()), - ExtensionSet::singleton(&ex_id), - ) - .into(); + let val: Value = + CustomSerialized::new(typ_int.clone(), YamlValue::Number(6.into()), ex_id.clone()) + .into(); let classic_t = Type::new_extension(typ_int.clone()); assert_matches!(classic_t.least_upper_bound(), TypeBound::Eq); classic_t.check_type(&val).unwrap(); diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index 0214ec15f..a82bd148e 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -79,7 +79,8 @@ impl DataflowOpTrait for Conditional { inputs .to_mut() .insert(0, Type::new_tuple_sum(self.tuple_sum_rows.clone())); - FunctionType::new(inputs, self.outputs.clone()).with_extension_delta(&self.extension_delta) + FunctionType::new(inputs, self.outputs.clone()) + .with_extension_delta(self.extension_delta.clone()) } } diff --git a/src/ops/dataflow.rs b/src/ops/dataflow.rs index 1dab51642..74ef64dde 100644 --- a/src/ops/dataflow.rs +++ b/src/ops/dataflow.rs @@ -96,7 +96,6 @@ impl DataflowOpTrait for Input { fn signature(&self) -> FunctionType { FunctionType::new(TypeRow::new(), self.types.clone()) - .with_extension_delta(&ExtensionSet::new()) } } impl DataflowOpTrait for Output { diff --git a/src/ops/leaf.rs b/src/ops/leaf.rs index 437124db7..4975926f6 100644 --- a/src/ops/leaf.rs +++ b/src/ops/leaf.rs @@ -6,11 +6,11 @@ use super::custom::{ExtensionOp, ExternalOp}; use super::dataflow::DataflowOpTrait; use super::{OpName, OpTag}; -use crate::extension::{ExtensionRegistry, SignatureError}; +use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; use crate::types::type_param::TypeArg; use crate::types::PolyFuncType; use crate::{ - extension::{ExtensionId, ExtensionSet}, + extension::ExtensionId, types::{EdgeKind, FunctionType, Type, TypeRow}, }; @@ -185,7 +185,7 @@ impl DataflowOpTrait for LeafOp { type_row, new_extension, } => FunctionType::new(type_row.clone(), type_row.clone()) - .with_extension_delta(&ExtensionSet::singleton(new_extension)), + .with_extension_delta(ExtensionSet::singleton(new_extension)), LeafOp::TypeApply { ta } => FunctionType::new( vec![Type::new_function(ta.input.clone())], vec![Type::new_function(ta.output.clone())], diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index 1575b75f8..be7c3f430 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -91,7 +91,7 @@ impl CustomConst for ListValue { fn extension_reqs(&self) -> ExtensionSet { ExtensionSet::union_over(self.0.iter().map(Value::extension_reqs)) - .union(&ExtensionSet::singleton(&EXTENSION_NAME)) + .union(&(EXTENSION_NAME.into())) } } diff --git a/src/types/signature.rs b/src/types/signature.rs index ebcd44a7c..e5ca5a56e 100644 --- a/src/types/signature.rs +++ b/src/types/signature.rs @@ -24,8 +24,8 @@ pub struct FunctionType { impl FunctionType { /// Builder method, add extension_reqs to an FunctionType - pub fn with_extension_delta(mut self, rs: &ExtensionSet) -> Self { - self.extension_reqs = self.extension_reqs.union(rs); + pub fn with_extension_delta(mut self, rs: impl Into) -> Self { + self.extension_reqs = self.extension_reqs.union(&rs.into()); self } diff --git a/src/values.rs b/src/values.rs index 974852a91..554e6ebd8 100644 --- a/src/values.rs +++ b/src/values.rs @@ -194,7 +194,8 @@ pub struct CustomSerialized { impl CustomSerialized { /// Creates a new [`CustomSerialized`]. - pub fn new(typ: CustomType, value: serde_yaml::Value, extensions: ExtensionSet) -> Self { + pub fn new(typ: CustomType, value: serde_yaml::Value, exts: impl Into) -> Self { + let extensions = exts.into(); Self { typ, value, @@ -260,7 +261,7 @@ pub(crate) mod test { Value::custom(CustomSerialized { typ: FLOAT64_CUSTOM_TYPE, value: serde_yaml::Value::Number(f.into()), - extensions: ExtensionSet::singleton(&float_types::EXTENSION_ID), + extensions: float_types::EXTENSION_ID.into(), }) }