From f1e1e6f4c69f4bc0b7c107d6947e3416ca855965 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 1 Dec 2023 10:33:07 +0000 Subject: [PATCH 01/27] Add Value::extension_reqs (not used yet) --- src/extension/prelude.rs | 6 ++- src/ops/constant.rs | 14 +++++-- src/std_extensions/arithmetic/float_types.rs | 6 ++- src/std_extensions/arithmetic/int_types.rs | 10 ++++- src/std_extensions/collections.rs | 9 +++- src/values.rs | 43 ++++++++++++++++++-- 6 files changed, 77 insertions(+), 11 deletions(-) diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index c5f587975..f96046ba8 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -14,7 +14,7 @@ use crate::{ Extension, }; -use super::{ExtensionRegistry, SignatureError, SignatureFromArgs}; +use super::{ExtensionRegistry, ExtensionSet, SignatureError, SignatureFromArgs}; struct ArrayOpCustom; const MAX: &[TypeParam; 1] = &[TypeParam::max_nat()]; @@ -181,6 +181,10 @@ impl CustomConst for ConstUsize { fn equal_consts(&self, other: &dyn CustomConst) -> bool { crate::values::downcast_equal_consts(self, other) } + + fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::singleton(&PRELUDE_ID) + } } impl KnownTypeConst for ConstUsize { diff --git a/src/ops/constant.rs b/src/ops/constant.rs index 5f87c96d2..d47c3c41c 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -159,10 +159,13 @@ mod test { let c = b.add_constant( Const::tuple_sum( 0, - Value::tuple([CustomTestValue(TypeBound::Eq).into(), serialized_float(5.1)]), + Value::tuple([ + CustomTestValue(TypeBound::Eq, ExtensionSet::new()).into(), + serialized_float(5.1), + ]), pred_rows.clone(), )?, - ExtensionSet::new(), + ExtensionSet::new(), // ALAN remove given above? )?; let w = b.load_const(&c)?; b.finish_hugr_with_outputs([w], &test_registry()).unwrap(); @@ -233,7 +236,12 @@ mod test { ex_id.clone(), TypeBound::Eq, ); - let val: Value = CustomSerialized::new(typ_int.clone(), YamlValue::Number(6.into())).into(); + let val: Value = CustomSerialized::new( + typ_int.clone(), + YamlValue::Number(6.into()), + ExtensionSet::singleton(&ex_id), + ) + .into(); let classic_t = Type::new_extension(typ_int.clone()); assert_matches!(classic_t.least_upper_bound(), TypeBound::Eq); classic_t.check_type(&val).unwrap(); diff --git a/src/std_extensions/arithmetic/float_types.rs b/src/std_extensions/arithmetic/float_types.rs index 32b7815ef..582c0eee7 100644 --- a/src/std_extensions/arithmetic/float_types.rs +++ b/src/std_extensions/arithmetic/float_types.rs @@ -3,7 +3,7 @@ use smol_str::SmolStr; use crate::{ - extension::ExtensionId, + extension::{ExtensionId, ExtensionSet}, types::{CustomCheckFailure, CustomType, Type, TypeBound}, values::{CustomConst, KnownTypeConst}, Extension, @@ -66,6 +66,10 @@ impl CustomConst for ConstF64 { fn equal_consts(&self, other: &dyn CustomConst) -> bool { crate::values::downcast_equal_consts(self, other) } + + fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::singleton(&EXTENSION_ID) + } } /// Extension for basic floating-point types. diff --git a/src/std_extensions/arithmetic/int_types.rs b/src/std_extensions/arithmetic/int_types.rs index 7a67de28a..1ff680503 100644 --- a/src/std_extensions/arithmetic/int_types.rs +++ b/src/std_extensions/arithmetic/int_types.rs @@ -5,7 +5,7 @@ use std::num::NonZeroU64; use smol_str::SmolStr; use crate::{ - extension::ExtensionId, + extension::{ExtensionId, ExtensionSet}, types::{ type_param::{TypeArg, TypeArgError, TypeParam}, ConstTypeError, CustomCheckFailure, CustomType, Type, TypeBound, @@ -161,6 +161,10 @@ impl CustomConst for ConstIntU { fn equal_consts(&self, other: &dyn CustomConst) -> bool { crate::values::downcast_equal_consts(self, other) } + + fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::singleton(&EXTENSION_ID) + } } #[typetag::serde] @@ -180,6 +184,10 @@ impl CustomConst for ConstIntS { fn equal_consts(&self, other: &dyn CustomConst) -> bool { crate::values::downcast_equal_consts(self, other) } + + fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::singleton(&EXTENSION_ID) + } } /// Extension for basic integer types. diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index 8b6e227ea..fa42b1a78 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use smol_str::SmolStr; use crate::{ - extension::{ExtensionId, TypeDef, TypeDefBound}, + extension::{ExtensionId, ExtensionSet, TypeDef, TypeDefBound}, types::{ type_param::{TypeArg, TypeParam}, CustomCheckFailure, CustomType, FunctionType, PolyFuncType, Type, TypeBound, @@ -66,6 +66,13 @@ impl CustomConst for ListValue { fn equal_consts(&self, other: &dyn CustomConst) -> bool { crate::values::downcast_equal_consts(self, other) } + + fn extension_reqs(&self) -> ExtensionSet { + self.0 + .iter() + .map(Value::extension_reqs) + .fold(ExtensionSet::singleton(&EXTENSION_NAME), |a, b| a.union(&b)) + } } const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; diff --git a/src/values.rs b/src/values.rs index 428654066..f6215760a 100644 --- a/src/values.rs +++ b/src/values.rs @@ -8,6 +8,7 @@ use std::any::Any; use downcast_rs::{impl_downcast, Downcast}; use smol_str::SmolStr; +use crate::extension::ExtensionSet; use crate::macros::impl_box_clone; use crate::{Hugr, HugrView}; @@ -115,6 +116,19 @@ impl Value { None } } + + /// The Extensions that must be supported to handle the value at runtime + pub fn extension_reqs(&self) -> ExtensionSet { + match self { + Value::Extension { c } => c.0.extension_reqs().clone(), + Value::Function { .. } => ExtensionSet::new(), // no extensions reqd to load Hugr (only to run) + Value::Tuple { vs } => vs + .iter() + .map(Value::extension_reqs) + .fold(ExtensionSet::new(), |a, b| a.union(&b)), + Value::Sum { value, .. } => value.extension_reqs(), + } + } } impl From for Value { @@ -134,6 +148,13 @@ pub trait CustomConst: /// An identifier for the constant. fn name(&self) -> SmolStr; + /// The extension(s) defining the custom value + /// (a set to allow, say, a [List] of [USize]) + /// + /// [List]: crate::std_extensions::collections::LIST_TYPENAME + /// [USize]: crate::extension::prelude::USIZE_T + fn extension_reqs(&self) -> ExtensionSet; + /// Check the value is a valid instance of the provided type. fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFailure>; @@ -184,12 +205,17 @@ impl_box_clone!(CustomConst, CustomConstBoxClone); pub struct CustomSerialized { typ: CustomType, value: serde_yaml::Value, + extensions: ExtensionSet, } impl CustomSerialized { /// Creates a new [`CustomSerialized`]. - pub fn new(typ: CustomType, value: serde_yaml::Value) -> Self { - Self { typ, value } + pub fn new(typ: CustomType, value: serde_yaml::Value, extensions: ExtensionSet) -> Self { + Self { + typ, + value, + extensions, + } } } @@ -213,6 +239,10 @@ impl CustomConst for CustomSerialized { fn equal_consts(&self, other: &dyn CustomConst) -> bool { Some(self) == other.downcast_ref() } + + fn extension_reqs(&self) -> ExtensionSet { + self.extensions.clone() + } } impl PartialEq for dyn CustomConst { @@ -227,7 +257,7 @@ pub(crate) mod test { use super::*; use crate::builder::test::simple_dfg_hugr; - use crate::std_extensions::arithmetic::float_types::FLOAT64_CUSTOM_TYPE; + use crate::std_extensions::arithmetic::float_types::{self, FLOAT64_CUSTOM_TYPE}; use crate::type_row; use crate::types::{FunctionType, Type, TypeBound}; @@ -235,7 +265,7 @@ pub(crate) mod test { /// A custom constant value used in testing that purports to be an instance /// of a custom type with a specific type bound. - pub(crate) struct CustomTestValue(pub TypeBound); + pub(crate) struct CustomTestValue(pub TypeBound, pub ExtensionSet); #[typetag::serde] impl CustomConst for CustomTestValue { fn name(&self) -> SmolStr { @@ -251,12 +281,17 @@ pub(crate) mod test { )) } } + + fn extension_reqs(&self) -> ExtensionSet { + self.1.clone() + } } pub(crate) fn serialized_float(f: f64) -> Value { Value::custom(CustomSerialized { typ: FLOAT64_CUSTOM_TYPE, value: serde_yaml::Value::Number(f.into()), + extensions: ExtensionSet::singleton(&float_types::EXTENSION_ID), }) } From cdc5785a16b44689548c1cb2f1659b5754df56a3 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 1 Dec 2023 10:40:53 +0000 Subject: [PATCH 02/27] Add OpTrait::extension_delta, non-empty for Const or DataflowOp --- src/extension/infer.rs | 12 ++++-------- src/hugr.rs | 13 +++---------- src/ops.rs | 8 ++++++++ src/ops/constant.rs | 5 +++++ src/ops/dataflow.rs | 3 +++ 5 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 0b99789c2..5d8176968 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -317,15 +317,11 @@ impl UnificationContext { match node_type.io_extensions() { // Input extensions are open None => { - let c = if let Some(sig) = node_type.op_signature() { - let delta = sig.extension_reqs; - if delta.is_empty() { - Constraint::Equal(m_input) - } else { - Constraint::Plus(delta, m_input) - } - } else { + let delta = node_type.op().extension_delta(); + let c = if delta.is_empty() { Constraint::Equal(m_input) + } else { + Constraint::Plus(delta, m_input) }; self.add_constraint(m_output, c); } diff --git a/src/hugr.rs b/src/hugr.rs index 2971885a8..9672f3dbb 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -125,16 +125,9 @@ impl NodeType { /// `None`` if the [Self::input_extensions] is `None`. /// Otherwise, will return Some, with the output extensions computed from the node's delta pub fn io_extensions(&self) -> Option<(&ExtensionSet, ExtensionSet)> { - self.input_extensions.as_ref().map(|e| { - ( - e, - self.op - .dataflow_signature() - .map(|ft| ft.extension_reqs) - .unwrap_or_default() - .union(e), - ) - }) + self.input_extensions + .as_ref() + .map(|e| (e, self.op.extension_delta().union(e))) } /// Gets the underlying [OpType] i.e. without any [input_extensions] diff --git a/src/ops.rs b/src/ops.rs index 7deb6328d..9c44f7f00 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -9,6 +9,7 @@ pub mod leaf; pub mod module; pub mod tag; pub mod validate; +use crate::extension::ExtensionSet; use crate::types::{EdgeKind, FunctionType, Type}; use crate::{Direction, OutgoingPort, Port}; use crate::{IncomingPort, PortIndex}; @@ -278,6 +279,13 @@ pub trait OpTrait { fn dataflow_signature(&self) -> Option { None } + + /// The delta between the input extensions specified for a node, + /// and the output extensions calculated for that node + fn extension_delta(&self) -> ExtensionSet { + ExtensionSet::new() + } + /// The edge kind for the non-dataflow or constant inputs of the operation, /// not described by the signature. /// diff --git a/src/ops/constant.rs b/src/ops/constant.rs index d47c3c41c..cb4a38047 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -1,6 +1,7 @@ //! Constant value definitions. use crate::{ + extension::ExtensionSet, types::{ConstTypeError, EdgeKind, Type, TypeRow}, values::{CustomConst, KnownTypeConst, Value}, }; @@ -96,6 +97,10 @@ impl OpTrait for Const { self.value.description() } + fn extension_delta(&self) -> ExtensionSet { + self.value.extension_reqs() + } + fn tag(&self) -> OpTag { ::TAG } diff --git a/src/ops/dataflow.rs b/src/ops/dataflow.rs index bf4fd83a9..d830fd0d2 100644 --- a/src/ops/dataflow.rs +++ b/src/ops/dataflow.rs @@ -115,6 +115,9 @@ impl OpTrait for T { fn dataflow_signature(&self) -> Option { Some(DataflowOpTrait::signature(self)) } + fn extension_delta(&self) -> ExtensionSet { + DataflowOpTrait::signature(self).extension_reqs.clone() + } fn other_input(&self) -> Option { DataflowOpTrait::other_input(self) } From 9e1b3eb234f5570f9a39d5c25714ebe3b73cc230 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 1 Dec 2023 12:10:05 +0000 Subject: [PATCH 03/27] Fix test_conditional_inference --- src/ops/controlflow.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index d2121cd36..afd0ef5f8 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -228,6 +228,10 @@ impl OpTrait for Case { "A case node inside a conditional" } + fn extension_delta(&self) -> ExtensionSet { + self.signature.extension_reqs.clone() + } + fn tag(&self) -> OpTag { ::TAG } From b47ad17bdea4e4a6a492db82c6aafab30c9d9426 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 1 Dec 2023 12:29:46 +0000 Subject: [PATCH 04/27] Fix test_tuple_sum by build_traits.rs: when adding load_const, do not use input extensions of Const --- src/builder/build_traits.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 9f467426a..ea33001a4 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -356,20 +356,16 @@ pub trait Dataflow: Container { fn load_const(&mut self, cid: &ConstID) -> Result { let const_node = cid.node(); let nodetype = self.hugr().get_nodetype(const_node); - let input_extensions = nodetype.input_extensions().cloned(); let op: ops::Const = nodetype .op() .clone() .try_into() .expect("ConstID does not refer to Const op."); - let load_n = self.add_dataflow_node( - NodeType::new( - ops::LoadConstant { - datatype: op.const_type().clone(), - }, - input_extensions, - ), + let load_n = self.add_dataflow_op( + ops::LoadConstant { + datatype: op.const_type().clone(), + }, // Constant wire from the constant value node vec![Wire::new(const_node, OutgoingPort::from(0))], )?; From bb66a8b62275bd7964954088e07b22385ffcd706 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 1 Dec 2023 12:33:56 +0000 Subject: [PATCH 05/27] Fix static_targets --- src/hugr/views/tests.rs | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index a2a3274b9..2e57ece82 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -3,7 +3,10 @@ use rstest::{fixture, rstest}; use crate::{ builder::{BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr}, - extension::prelude::QB_T, + extension::{ + prelude::{PRELUDE_ID, QB_T}, + ExtensionSet, + }, ops::handle::{DataflowOpID, NodeHandle}, type_row, types::FunctionType, @@ -134,16 +137,19 @@ fn value_types() { fn static_targets() { use crate::extension::prelude::{ConstUsize, USIZE_T}; use itertools::Itertools; + let mut dfg = DFGBuilder::new( + FunctionType::new(type_row![], type_row![USIZE_T]) + .with_extension_delta(&ExtensionSet::singleton(&PRELUDE_ID)), + ) + .unwrap(); - let mut dfg = DFGBuilder::new(FunctionType::new(type_row![], type_row![USIZE_T])).unwrap(); - - let c = dfg.add_constant(ConstUsize::new(1).into(), None).unwrap(); + let c = dfg + .add_constant(ConstUsize::new(1).into(), ExtensionSet::new()) + .unwrap(); let load = dfg.load_const(&c).unwrap(); - let h = dfg - .finish_hugr_with_outputs([load], &crate::extension::PRELUDE_REGISTRY) - .unwrap(); + let h = dfg.finish_prelude_hugr_with_outputs([load]).unwrap(); assert_eq!(h.static_source(load.node()), Some(c.node())); From 35e1787d5ed5b9c5a89e84588e3de5664bc6fd43 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 1 Dec 2023 12:41:48 +0000 Subject: [PATCH 06/27] add_constant drop ExtensionSet parameter - always empty --- src/algorithm/nest_cfgs.rs | 12 ++++++------ src/builder/build_traits.rs | 16 ++++------------ src/builder/cfg.rs | 8 +++----- src/builder/conditional.rs | 2 +- src/builder/tail_loop.rs | 13 +++---------- src/hugr/rewrite/outline_cfg.rs | 2 +- src/hugr/rewrite/replace.rs | 2 +- src/hugr/validate/test.rs | 2 +- src/hugr/views/tests.rs | 4 +--- src/ops/constant.rs | 24 +++++++++--------------- 10 files changed, 30 insertions(+), 55 deletions(-) diff --git a/src/algorithm/nest_cfgs.rs b/src/algorithm/nest_cfgs.rs index 15154d4f2..70d5aa26c 100644 --- a/src/algorithm/nest_cfgs.rs +++ b/src/algorithm/nest_cfgs.rs @@ -605,8 +605,8 @@ pub(crate) mod test { // \-> right -/ \-<--<-/ let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; - let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which - let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?; + let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which + let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?; let entry = n_identity( cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?, @@ -887,8 +887,8 @@ pub(crate) mod test { separate: bool, ) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; - let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which - let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?; + let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which + let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?; let entry = n_identity( cfg_builder.simple_entry_builder(type_row![NAT], 2, ExtensionSet::new())?, @@ -929,8 +929,8 @@ pub(crate) mod test { cfg_builder: &mut CFGBuilder, separate_headers: bool, ) -> Result<(BasicBlockID, BasicBlockID), BuildError> { - let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which - let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?; + let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which + let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?; let entry = n_identity( cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?, diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index ea33001a4..e71c5633f 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -70,12 +70,8 @@ pub trait Container { /// /// This function will return an error if there is an error in adding the /// [`OpType::Const`] node. - fn add_constant( - &mut self, - constant: ops::Const, - extensions: impl Into>, - ) -> Result { - let const_n = self.add_child_node(NodeType::new(constant, extensions.into()))?; + fn add_constant(&mut self, constant: ops::Const) -> Result { + let const_n = self.add_child_node(NodeType::new(constant, ExtensionSet::new()))?; Ok(const_n.into()) } @@ -378,12 +374,8 @@ pub trait Dataflow: Container { /// # Errors /// /// This function will return an error if there is an error when adding the node. - fn add_load_const( - &mut self, - constant: ops::Const, - extensions: ExtensionSet, - ) -> Result { - let cid = self.add_constant(constant, extensions)?; + fn add_load_const(&mut self, constant: ops::Const) -> Result { + let cid = self.add_constant(constant)?; self.load_const(&cid) } diff --git a/src/builder/cfg.rs b/src/builder/cfg.rs index 99781ea2c..9de97bc9c 100644 --- a/src/builder/cfg.rs +++ b/src/builder/cfg.rs @@ -385,7 +385,7 @@ mod test { let mut middle_b = cfg_builder .simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?; let middle = { - let c = middle_b.add_load_const(ops::Const::unary_unit_sum(), ExtensionSet::new())?; + let c = middle_b.add_load_const(ops::Const::unary_unit_sum())?; let [inw] = middle_b.input_wires_arr(); middle_b.finish_with_outputs(c, [inw])? }; @@ -398,8 +398,7 @@ mod test { #[test] fn test_dom_edge() -> Result<(), BuildError> { let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; - let sum_tuple_const = - cfg_builder.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())?; + let sum_tuple_const = cfg_builder.add_constant(ops::Const::unary_unit_sum())?; let sum_variants = vec![type_row![]]; let mut entry_b = @@ -427,8 +426,7 @@ mod test { #[test] fn test_non_dom_edge() -> Result<(), BuildError> { let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; - let sum_tuple_const = - cfg_builder.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())?; + let sum_tuple_const = cfg_builder.add_constant(ops::Const::unary_unit_sum())?; let sum_variants = vec![type_row![]]; let mut middle_b = cfg_builder .simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?; diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index 0238d14a4..1e3441968 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -242,7 +242,7 @@ mod test { "main", FunctionType::new(type_row![NAT], type_row![NAT]).into(), )?; - let tru_const = fbuild.add_constant(Const::true_val(), ExtensionSet::new())?; + let tru_const = fbuild.add_constant(Const::true_val())?; let _fdef = { let const_wire = fbuild.load_const(&tru_const)?; let [int] = fbuild.input_wires_arr(); diff --git a/src/builder/tail_loop.rs b/src/builder/tail_loop.rs index 9ab71182b..bbcddade7 100644 --- a/src/builder/tail_loop.rs +++ b/src/builder/tail_loop.rs @@ -109,10 +109,7 @@ mod test { let build_result: Result = { let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T])?; let [i1] = loop_b.input_wires_arr(); - let const_wire = loop_b.add_load_const( - ConstUsize::new(1).into(), - ExtensionSet::singleton(&PRELUDE_ID), - )?; + let const_wire = loop_b.add_load_const(ConstUsize::new(1).into())?; let break_wire = loop_b.make_break(loop_b.loop_signature()?.clone(), [const_wire])?; loop_b.set_outputs(break_wire, [i1])?; @@ -148,8 +145,7 @@ mod test { fbuild.tail_loop_builder(vec![(BIT, b1)], vec![], type_row![NAT])?; let signature = loop_b.loop_signature()?.clone(); let const_val = Const::true_val(); - let const_wire = - loop_b.add_load_const(Const::true_val(), ExtensionSet::new())?; + let const_wire = loop_b.add_load_const(Const::true_val())?; let lift_node = loop_b.add_dataflow_op( ops::LeafOp::Lift { type_row: vec![const_val.const_type().clone()].into(), @@ -177,10 +173,7 @@ mod test { let mut branch_1 = conditional_b.case_builder(1)?; let [_b1] = branch_1.input_wires_arr(); - let wire = branch_1.add_load_const( - ConstUsize::new(2).into(), - ExtensionSet::singleton(&PRELUDE_ID), - )?; + let wire = branch_1.add_load_const(ConstUsize::new(2).into())?; let break_wire = branch_1.make_break(signature, [wire])?; branch_1.finish_with_outputs([break_wire])?; diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index d0640048a..c13a4183f 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -142,7 +142,7 @@ impl Rewrite for OutlineCfg { .unwrap(); let cfg = cfg.finish_sub_container().unwrap(); let unit_sum = new_block_bldr - .add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new()) + .add_constant(ops::Const::unary_unit_sum()) .unwrap(); let pred_wire = new_block_bldr.load_const(&unit_sum).unwrap(); new_block_bldr diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index dd478969f..4e3a49616 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -484,7 +484,7 @@ mod test { FunctionType::new_endo(just_list.clone()).with_extension_delta(&exset), )?; - let pred_const = cfg.add_constant(ops::Const::unary_unit_sum(), None)?; + let pred_const = cfg.add_constant(ops::Const::unary_unit_sum())?; let entry = single_node_block(&mut cfg, pop, &pred_const, true)?; let bb2 = single_node_block(&mut cfg, push, &pred_const, false)?; diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index 8b1545049..dc8e9add2 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -888,7 +888,7 @@ fn no_polymorphic_consts() -> Result<(), Box> { let empty_list = Value::Extension { c: (Box::new(collections::ListValue::new(vec![])),), }; - let cst = def.add_load_const(Const::new(empty_list, list_of_var)?, just_colns)?; + let cst = def.add_load_const(Const::new(empty_list, list_of_var)?)?; let res = def.finish_hugr_with_outputs([cst], ®); assert_matches!( res.unwrap_err(), diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 2e57ece82..9f356a586 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -143,9 +143,7 @@ fn static_targets() { ) .unwrap(); - let c = dfg - .add_constant(ConstUsize::new(1).into(), ExtensionSet::new()) - .unwrap(); + let c = dfg.add_constant(ConstUsize::new(1).into()).unwrap(); let load = dfg.load_const(&c).unwrap(); diff --git a/src/ops/constant.rs b/src/ops/constant.rs index cb4a38047..6db0006b3 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -161,25 +161,19 @@ mod test { type_row![], TypeRow::from(vec![pred_ty.clone()]), ))?; - let c = b.add_constant( - Const::tuple_sum( - 0, - Value::tuple([ - CustomTestValue(TypeBound::Eq, ExtensionSet::new()).into(), - serialized_float(5.1), - ]), - pred_rows.clone(), - )?, - ExtensionSet::new(), // ALAN remove given above? - )?; + let c = b.add_constant(Const::tuple_sum( + 0, + Value::tuple([ + CustomTestValue(TypeBound::Eq, ExtensionSet::new()).into(), + serialized_float(5.1), + ]), + pred_rows.clone(), + )?)?; let w = b.load_const(&c)?; b.finish_hugr_with_outputs([w], &test_registry()).unwrap(); let mut b = DFGBuilder::new(FunctionType::new(type_row![], TypeRow::from(vec![pred_ty])))?; - let c = b.add_constant( - Const::tuple_sum(1, Value::unit(), pred_rows)?, - ExtensionSet::new(), - )?; + let c = b.add_constant(Const::tuple_sum(1, Value::unit(), pred_rows)?)?; let w = b.load_const(&c)?; b.finish_hugr_with_outputs([w], &test_registry()).unwrap(); From 1da3a63ab51f98dd52e8fa93ff8949259292837f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 1 Dec 2023 18:18:01 +0000 Subject: [PATCH 07/27] Fix replace::test::cfg (pending issue #388) --- src/hugr/rewrite/replace.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index 4e3a49616..b2d6bf0c4 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -477,11 +477,12 @@ mod test { .unwrap() .into(); let just_list = TypeRow::from(vec![listy.clone()]); - let exset = ExtensionSet::singleton(&collections::EXTENSION_NAME); let intermed = TypeRow::from(vec![listy.clone(), USIZE_T]); let mut cfg = CFGBuilder::new( - FunctionType::new_endo(just_list.clone()).with_extension_delta(&exset), + // One might expect an extension_delta of "collections" here, but push/pop + // have an empty delta themselves, pending https://github.com/CQCL/hugr/issues/388 + FunctionType::new_endo(just_list.clone()), )?; let pred_const = cfg.add_constant(ops::Const::unary_unit_sum())?; From 6c76919f6f904b284409a1bc4bfe4b59424a78a6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sat, 2 Dec 2023 10:18:25 +0000 Subject: [PATCH 08/27] clippy (cross-version issues) --- src/hugr/views/tests.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 9f356a586..2ae8c906c 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -3,10 +3,7 @@ use rstest::{fixture, rstest}; use crate::{ builder::{BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr}, - extension::{ - prelude::{PRELUDE_ID, QB_T}, - ExtensionSet, - }, + extension::prelude::QB_T, ops::handle::{DataflowOpID, NodeHandle}, type_row, types::FunctionType, @@ -135,7 +132,7 @@ fn value_types() { #[rustversion::since(1.75)] // uses impl in return position #[test] fn static_targets() { - use crate::extension::prelude::{ConstUsize, USIZE_T}; + use crate::extension::{ExtensionSet, prelude::{ConstUsize, USIZE_T, PRELUDE_ID}}; use itertools::Itertools; let mut dfg = DFGBuilder::new( FunctionType::new(type_row![], type_row![USIZE_T]) From 25852a3ba6eba8b5467f4df9e690988cb5517bff Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sat, 2 Dec 2023 10:20:02 +0000 Subject: [PATCH 09/27] ...and fmt --- src/hugr/views/tests.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 2ae8c906c..97fb50861 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -132,7 +132,10 @@ fn value_types() { #[rustversion::since(1.75)] // uses impl in return position #[test] fn static_targets() { - use crate::extension::{ExtensionSet, prelude::{ConstUsize, USIZE_T, PRELUDE_ID}}; + use crate::extension::{ + prelude::{ConstUsize, PRELUDE_ID, USIZE_T}, + ExtensionSet, + }; use itertools::Itertools; let mut dfg = DFGBuilder::new( FunctionType::new(type_row![], type_row![USIZE_T]) From 4856b17966dba3d84414be4d33fbdd7926295c2e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 1 Dec 2023 23:17:03 +0000 Subject: [PATCH 10/27] Union OpDef's extension with that from SignatureFunc - in former not latter --- src/extension/op_def.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 5ce8b483f..56bdf3f64 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -223,7 +223,7 @@ impl SignatureFunc { /// /// This function will return an error if the type arguments are invalid or /// there is some error in type computation. - pub fn compute_signature( + fn compute_signature( &self, def: &OpDef, args: &[TypeArg], @@ -246,9 +246,6 @@ impl SignatureFunc { }; let res = pf.instantiate(args, exts)?; - // TODO bring this assert back once resource inference is done? - // https://github.com/CQCL/hugr/issues/388 - // debug_assert!(res.extension_reqs.contains(def.extension())); Ok(res) } } @@ -346,7 +343,11 @@ impl OpDef { args: &[TypeArg], exts: &ExtensionRegistry, ) -> Result { - self.signature_func.compute_signature(self, args, exts) + let mut functy = self.signature_func.compute_signature(self, args, exts)?; + functy.extension_reqs = functy + .extension_reqs + .union(&ExtensionSet::singleton(&self.extension)); + Ok(functy) } pub(crate) fn should_serialize_signature(&self) -> bool { From d07d1865894f5e5e6f737a002c80f965dfeb2047 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 1 Dec 2023 23:34:58 +0000 Subject: [PATCH 11/27] Fix simple_linear (with lift) --- src/builder/circuit.rs | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/builder/circuit.rs b/src/builder/circuit.rs index 054f839fa..a5b3c2ecc 100644 --- a/src/builder/circuit.rs +++ b/src/builder/circuit.rs @@ -137,22 +137,33 @@ mod test { test::{build_main, NAT, QB}, Dataflow, DataflowSubContainer, Wire, }, - extension::prelude::BOOL_T, + extension::{prelude::BOOL_T, ExtensionSet}, ops::{custom::OpaqueOp, LeafOp}, type_row, types::FunctionType, - utils::test_quantum_extension::{cx_gate, h_gate, measure}, + utils::test_quantum_extension::{cx_gate, h_gate, measure, EXTENSION_ID}, }; #[test] fn simple_linear() { let build_res = build_main( - FunctionType::new(type_row![QB, QB], type_row![QB, QB]).into(), + FunctionType::new_endo(type_row![QB, QB]) + .with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID)) + .into(), |mut f_build| { - let wires = f_build.input_wires().collect(); + let mut wires: [Wire; 2] = f_build.input_wires_arr(); + [wires[1]] = f_build + .add_dataflow_op( + LeafOp::Lift { + type_row: vec![QB].into(), + new_extension: EXTENSION_ID, + }, + [wires[1]], + )? + .outputs_arr(); let mut linear = CircuitBuilder { - wires, + wires: Vec::from(wires), builder: &mut f_build, }; From dee08254f35e70c972c8607e780ca8a2af70611d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 1 Dec 2023 23:39:06 +0000 Subject: [PATCH 12/27] Fix nonlinear_and_outputs with another Lift --- src/builder/circuit.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/builder/circuit.rs b/src/builder/circuit.rs index a5b3c2ecc..eb47b519e 100644 --- a/src/builder/circuit.rs +++ b/src/builder/circuit.rs @@ -195,10 +195,20 @@ mod test { .into(), ); let build_res = build_main( - FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]).into(), + FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]) + .with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID)) + .into(), |mut f_build| { let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr(); - + let [angle] = f_build + .add_dataflow_op( + LeafOp::Lift { + type_row: vec![NAT].into(), + new_extension: EXTENSION_ID, + }, + [angle], + )? + .outputs_arr(); let mut linear = f_build.as_circuit(vec![q0, q1]); let measure_out = linear From 5ae9281431d984cf5fda3db40309b16251ea5813 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 1 Dec 2023 23:52:05 +0000 Subject: [PATCH 13/27] Fix nested_identity + copy_insertion (many Lift's + parameterize over ExtensionSet) --- src/builder/dataflow.rs | 48 +++++++++++++++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/src/builder/dataflow.rs b/src/builder/dataflow.rs index 4761696c7..1ca5418af 100644 --- a/src/builder/dataflow.rs +++ b/src/builder/dataflow.rs @@ -213,9 +213,9 @@ pub(crate) mod test { use crate::hugr::validate::InterGraphEdgeError; use crate::ops::{handle::NodeHandle, LeafOp, OpTag}; - use crate::std_extensions::logic::test::and_op; + use crate::std_extensions::logic::{self, test::and_op}; use crate::types::Type; - use crate::utils::test_quantum_extension::h_gate; + use crate::utils::test_quantum_extension::{self, h_gate}; use crate::{ builder::{ test::{n_identity, BIT, NAT, QB}, @@ -235,13 +235,25 @@ pub(crate) mod test { let _f_id = { let mut func_builder = module_builder.define_function( "main", - FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).into(), + FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]) + .with_extension_delta(&ExtensionSet::singleton( + &test_quantum_extension::EXTENSION_ID, + )) + .into(), )?; let [int, qb] = func_builder.input_wires_arr(); let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb])?; - + let [int] = func_builder + .add_dataflow_op( + LeafOp::Lift { + type_row: vec![NAT].into(), + new_extension: test_quantum_extension::EXTENSION_ID, + }, + [int], + )? + .outputs_arr(); let inner_builder = func_builder.dfg_builder( FunctionType::new(type_row![NAT], type_row![NAT]), None, @@ -260,7 +272,7 @@ pub(crate) mod test { } // Scaffolding for copy insertion tests - fn copy_scaffold(f: F, msg: &'static str) -> Result<(), BuildError> + fn copy_scaffold(f: F, delta: &ExtensionSet, msg: &'static str) -> Result<(), BuildError> where F: FnOnce(FunctionBuilder<&mut Hugr>) -> Result>, BuildError>, { @@ -269,7 +281,9 @@ pub(crate) mod test { let f_build = module_builder.define_function( "main", - FunctionType::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]).into(), + FunctionType::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]) + .with_extension_delta(delta) + .into(), )?; f(f_build)?; @@ -287,15 +301,27 @@ pub(crate) mod test { let [b1] = f_build.input_wires_arr(); f_build.finish_with_outputs([b1, b1]) }, + &ExtensionSet::new(), "Copy input and output", )?; + let es = ExtensionSet::singleton(&logic::EXTENSION_ID); copy_scaffold( |mut f_build| { let [b1] = f_build.input_wires_arr(); let xor = f_build.add_dataflow_op(and_op(), [b1, b1])?; + let [b1] = f_build + .add_dataflow_op( + LeafOp::Lift { + type_row: vec![BOOL_T].into(), + new_extension: logic::EXTENSION_ID, + }, + [b1], + )? + .outputs_arr(); f_build.finish_with_outputs([xor.out_wire(0), b1]) }, + &es, "Copy input and use with binary function", )?; @@ -303,9 +329,19 @@ pub(crate) mod test { |mut f_build| { let [b1] = f_build.input_wires_arr(); let xor1 = f_build.add_dataflow_op(and_op(), [b1, b1])?; + let [b1] = f_build + .add_dataflow_op( + LeafOp::Lift { + type_row: vec![BOOL_T].into(), + new_extension: logic::EXTENSION_ID, + }, + [b1], + )? + .outputs_arr(); let xor2 = f_build.add_dataflow_op(and_op(), [b1, xor1.out_wire(0)])?; f_build.finish_with_outputs([xor2.out_wire(0), b1]) }, + &es, "Copy multiple times", )?; From 265df79ed6d522cdad71bf49a6fe14f2ee96f41f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sat, 2 Dec 2023 09:38:01 +0000 Subject: [PATCH 14/27] fix op_def.rs tests --- src/extension/op_def.rs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 56bdf3f64..7c68415bd 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -549,10 +549,10 @@ mod test { let args = [TypeArg::BoundedNat { n: 3 }, USIZE_T.into()]; assert_eq!( def.compute_signature(&args, &PRELUDE_REGISTRY), - Ok(FunctionType::new( - vec![USIZE_T; 3], - vec![Type::new_tuple(vec![USIZE_T; 3])] - )) + Ok( + FunctionType::new(vec![USIZE_T; 3], vec![Type::new_tuple(vec![USIZE_T; 3])]) + .with_extension_delta(&ExtensionSet::singleton(&EXT_ID)) + ) ); assert_eq!(def.validate_args(&args, &PRELUDE_REGISTRY, &[]), Ok(())); @@ -562,10 +562,10 @@ mod test { let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()]; assert_eq!( def.compute_signature(&args, &PRELUDE_REGISTRY), - Ok(FunctionType::new( - tyvars.clone(), - vec![Type::new_tuple(tyvars)] - )) + Ok( + FunctionType::new(tyvars.clone(), vec![Type::new_tuple(tyvars)]) + .with_extension_delta(&ExtensionSet::singleton(&EXT_ID)) + ) ); def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeBound::Eq.into()]) .unwrap(); @@ -618,7 +618,8 @@ mod test { def.validate_args(&args, &EMPTY_REG, &decls).unwrap(); assert_eq!( def.compute_signature(&args, &EMPTY_REG), - Ok(FunctionType::new_endo(vec![tv])) + Ok(FunctionType::new_endo(vec![tv]) + .with_extension_delta(&ExtensionSet::singleton(&EXT_ID))) ); Ok(()) } From b87c8848b9ed96a15738c80c72c44ee61165a7c3 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sat, 2 Dec 2023 09:41:17 +0000 Subject: [PATCH 15/27] Fix search_variable_deps to handle solved Meta, and fix replace test --- src/extension/infer.rs | 3 ++- src/hugr/rewrite/replace.rs | 29 +++++++++++++++++++++++------ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 5d8176968..ac32db07d 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -648,11 +648,12 @@ impl UnificationContext { fn search_variable_deps(&self) -> HashSet { let mut seen = HashSet::new(); let mut new_variables: HashSet = self.variables.clone(); + let constraints_for_solved = HashSet::new(); while !new_variables.is_empty() { new_variables = new_variables .into_iter() .filter(|m| seen.insert(*m)) - .flat_map(|m| self.get_constraints(&m).unwrap()) + .flat_map(|m| self.get_constraints(&m).unwrap_or(&constraints_for_solved)) .map(|c| match c { Constraint::Plus(_, other) => self.resolve(*other), Constraint::Equal(other) => self.resolve(*other), diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index b2d6bf0c4..c34501b64 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -477,12 +477,11 @@ mod test { .unwrap() .into(); let just_list = TypeRow::from(vec![listy.clone()]); + let exset = ExtensionSet::singleton(&collections::EXTENSION_NAME); let intermed = TypeRow::from(vec![listy.clone(), USIZE_T]); let mut cfg = CFGBuilder::new( - // One might expect an extension_delta of "collections" here, but push/pop - // have an empty delta themselves, pending https://github.com/CQCL/hugr/issues/388 - FunctionType::new_endo(just_list.clone()), + FunctionType::new_endo(just_list.clone()).with_extension_delta(&exset), )?; let pred_const = cfg.add_constant(ops::Const::unary_unit_sum())?; @@ -628,13 +627,31 @@ mod test { }, op_sig.input() ); - h.simple_entry_builder(op_sig.output, 1, op_sig.extension_reqs.clone())? + h.simple_entry_builder(op_sig.output.clone(), 1, op_sig.extension_reqs.clone())? } else { - h.simple_block_builder(op_sig, 1)? + h.simple_block_builder(op_sig.clone(), 1)? }; let op: OpType = op.into(); let op = bb.add_dataflow_op(op, bb.input_wires())?; - let load_pred = bb.load_const(pred_const)?; + let mut load_pred = bb.load_const(pred_const)?; + let const_ty = bb + .hugr() + .get_optype(pred_const.node()) + .as_const() + .unwrap() + .const_type() + .clone(); + for e in op_sig.extension_reqs.iter() { + [load_pred] = bb + .add_dataflow_op( + LeafOp::Lift { + type_row: vec![const_ty.clone()].into(), + new_extension: e.clone(), + }, + [load_pred], + )? + .outputs_arr(); + } bb.finish_with_outputs(load_pred, op.outputs()) } From 97f9b876fab0f6bffb0f5e2150807ba3f4f5cef8 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sat, 2 Dec 2023 10:09:38 +0000 Subject: [PATCH 16/27] Fix (simple_)replacement tests --- src/hugr/rewrite/simple_replace.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index 2f14dee3e..469017d86 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -221,7 +221,7 @@ pub(in crate::hugr::rewrite) mod test { HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::BOOL_T; - use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY}; + use crate::extension::{ExtensionSet, EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::views::{HugrView, SiblingSubgraph}; use crate::hugr::{Hugr, HugrMut, Rewrite}; use crate::ops::dataflow::DataflowOpTrait; @@ -230,7 +230,7 @@ pub(in crate::hugr::rewrite) mod test { use crate::std_extensions::logic::test::and_op; use crate::type_row; use crate::types::{FunctionType, Type}; - use crate::utils::test_quantum_extension::{cx_gate, h_gate}; + use crate::utils::test_quantum_extension::{cx_gate, h_gate, EXTENSION_ID}; use crate::{IncomingPort, Node}; use super::SimpleReplacement; @@ -249,9 +249,12 @@ pub(in crate::hugr::rewrite) mod test { fn make_hugr() -> Result { let mut module_builder = ModuleBuilder::new(); let _f_id = { + let delta = ExtensionSet::singleton(&EXTENSION_ID); let mut func_builder = module_builder.define_function( "main", - FunctionType::new(type_row![QB, QB, QB], type_row![QB, QB, QB]).into(), + FunctionType::new_endo(type_row![QB, QB, QB]) + .with_extension_delta(&delta) + .into(), )?; let [qb0, qb1, qb2] = func_builder.input_wires_arr(); @@ -259,7 +262,7 @@ pub(in crate::hugr::rewrite) mod test { let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb2])?; let mut inner_builder = func_builder.dfg_builder( - FunctionType::new(type_row![QB, QB], type_row![QB, QB]), + FunctionType::new_endo(type_row![QB, QB]).with_extension_delta(&delta), None, [qb0, qb1], )?; From 9f51a9bc9507de1fb0ceee6b869beecbfb2680ad Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sat, 2 Dec 2023 10:13:15 +0000 Subject: [PATCH 17/27] fix test_ext_edge --- src/hugr/validate/test.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index dc8e9add2..7236504c7 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -16,7 +16,7 @@ use crate::macros::const_extension_ids; use crate::ops::dataflow::IOTrait; use crate::ops::{self, Const, LeafOp, OpType}; use crate::std_extensions::logic::test::{and_op, or_op}; -use crate::std_extensions::logic::{self, NotOp}; +use crate::std_extensions::logic::{self, NotOp, EXTENSION_ID}; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound, TypeRow}; use crate::values::Value; @@ -360,17 +360,18 @@ fn cfg_children_restrictions() { #[test] fn test_ext_edge() -> Result<(), HugrError> { - let mut h = closed_dfg_root_hugr(FunctionType::new( - type_row![BOOL_T, BOOL_T], - type_row![BOOL_T], - )); + let delta = ExtensionSet::singleton(&EXTENSION_ID); + let mut h = closed_dfg_root_hugr( + FunctionType::new(type_row![BOOL_T, BOOL_T], type_row![BOOL_T]) + .with_extension_delta(&delta), + ); let [input, output] = h.get_io(h.root()).unwrap(); // Nested DFG BOOL_T -> BOOL_T let sub_dfg = h.add_node_with_parent( h.root(), ops::DFG { - signature: FunctionType::new_endo(type_row![BOOL_T]), + signature: FunctionType::new_endo(type_row![BOOL_T]).with_extension_delta(&delta), }, )?; // this Xor has its 2nd input unconnected From bc585f40bc7911047e70bb8780ef1532284fb1e7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sat, 2 Dec 2023 12:21:26 +0000 Subject: [PATCH 18/27] fix test_local_const --- src/hugr/validate/test.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index 7236504c7..c20c8a767 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -412,7 +412,10 @@ const_extension_ids! { #[test] fn test_local_const() -> Result<(), HugrError> { - let mut h = closed_dfg_root_hugr(FunctionType::new(type_row![BOOL_T], type_row![BOOL_T])); + let mut h = closed_dfg_root_hugr( + FunctionType::new_endo(type_row![BOOL_T]) + .with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID)), + ); let [input, output] = h.get_io(h.root()).unwrap(); let and = h.add_node_with_parent(h.root(), and_op())?; h.connect(input, 0, and, 0)?; From 472150549dafecf944c05735612b25af52a42cbf Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sat, 2 Dec 2023 12:32:53 +0000 Subject: [PATCH 19/27] fix full_region/flat_region tests --- src/hugr/views/descendants.rs | 30 +++++++++++++++++++++++------- src/hugr/views/sibling.rs | 2 +- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/hugr/views/descendants.rs b/src/hugr/views/descendants.rs index f0c515457..8b89e9836 100644 --- a/src/hugr/views/descendants.rs +++ b/src/hugr/views/descendants.rs @@ -201,10 +201,11 @@ where pub(super) mod test { use crate::{ builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, - ops::handle::NodeHandle, + extension::ExtensionSet, + ops::{handle::NodeHandle, LeafOp}, type_row, types::{FunctionType, Type}, - utils::test_quantum_extension::h_gate, + utils::test_quantum_extension::{h_gate, EXTENSION_ID}, }; use super::*; @@ -222,16 +223,27 @@ pub(super) mod test { let (f_id, inner_id) = { let mut func_builder = module_builder.define_function( "main", - FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).into(), + FunctionType::new_endo(type_row![NAT, QB]) + .with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID)) + .into(), )?; let [int, qb] = func_builder.input_wires_arr(); let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb])?; + let [int] = func_builder + .add_dataflow_op( + LeafOp::Lift { + type_row: type_row![NAT], + new_extension: EXTENSION_ID, + }, + [int], + )? + .outputs_arr(); let inner_id = { let inner_builder = func_builder.dfg_builder( - FunctionType::new(type_row![NAT], type_row![NAT]), + FunctionType::new_endo(type_row![NAT]), None, [int], )?; @@ -249,11 +261,11 @@ pub(super) mod test { #[test] fn full_region() -> Result<(), Box> { - let (hugr, def, inner) = make_module_hgr()?; + let (hugr, def, inner) = make_module_hgr().unwrap(); let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; - assert_eq!(region.node_count(), 7); + assert_eq!(region.node_count(), 8); assert!(region.nodes().all(|n| n == def || hugr.get_parent(n) == Some(def) || hugr.get_parent(n) == Some(inner))); @@ -261,7 +273,11 @@ pub(super) mod test { assert_eq!( region.get_function_type(), - Some(FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).into()) + Some( + FunctionType::new_endo(type_row![NAT, QB]) + .with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID)) + .into() + ) ); let inner_region: DescendantsGraph = DescendantsGraph::try_new(&hugr, inner)?; assert_eq!( diff --git a/src/hugr/views/sibling.rs b/src/hugr/views/sibling.rs index bcc122361..4bf6c7ac2 100644 --- a/src/hugr/views/sibling.rs +++ b/src/hugr/views/sibling.rs @@ -387,7 +387,7 @@ mod test { let region: SiblingGraph = SiblingGraph::try_new(&hugr, def)?; - assert_eq!(region.node_count(), 5); + assert_eq!(region.node_count(), 6); assert!(region .nodes() .all(|n| n == def || hugr.get_parent(n) == Some(def))); From 91c1b226a35053ba519c14a53f34298e46319d13 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sat, 2 Dec 2023 12:36:40 +0000 Subject: [PATCH 20/27] Fix test_binary_signatures --- src/std_extensions/arithmetic/int_ops.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index cd9221deb..7c376eb3f 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -468,9 +468,11 @@ mod test { PRELUDE.to_owned(), ]) .unwrap(); + let delta = ExtensionSet::singleton(&EXTENSION_ID); assert_eq!( iwiden_s.compute_signature(&[ta(3), ta(4)], ®).unwrap(), - FunctionType::new(vec![int_type(ta(3))], vec![int_type(ta(4))],) + FunctionType::new(vec![int_type(ta(3))], vec![int_type(ta(4))]) + .with_extension_delta(&delta) ); let iwiden_u = EXTENSION.get_op("iwiden_u").unwrap(); @@ -482,7 +484,8 @@ mod test { assert_eq!( inarrow_s.compute_signature(&[ta(2), ta(1)], ®).unwrap(), - FunctionType::new(vec![int_type(ta(2))], vec![sum_with_error(int_type(ta(1)))],) + FunctionType::new(vec![int_type(ta(2))], vec![sum_with_error(int_type(ta(1)))]) + .with_extension_delta(&delta) ); let inarrow_u = EXTENSION.get_op("inarrow_u").unwrap(); From 3a281c49984bc470591d00c0210728d0b805c222 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sat, 2 Dec 2023 13:58:18 +0000 Subject: [PATCH 21/27] Fix dataflow_ports_only --- src/hugr/views/tests.rs | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 97fb50861..145fe0bc7 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -3,7 +3,7 @@ use rstest::{fixture, rstest}; use crate::{ builder::{BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr}, - extension::prelude::QB_T, + extension::{prelude::QB_T, ExtensionSet}, ops::handle::{DataflowOpID, NodeHandle}, type_row, types::FunctionType, @@ -163,10 +163,14 @@ fn test_dataflow_ports_only() { use crate::builder::DataflowSubContainer; use crate::extension::{prelude::BOOL_T, PRELUDE_REGISTRY}; use crate::hugr::views::PortIterator; - use crate::std_extensions::logic::NotOp; + use crate::std_extensions::logic::{NotOp, EXTENSION_ID}; use itertools::Itertools; - let mut dfg = DFGBuilder::new(FunctionType::new(type_row![BOOL_T], type_row![BOOL_T])).unwrap(); + let mut dfg = DFGBuilder::new( + FunctionType::new_endo(type_row![BOOL_T]) + .with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID)), + ) + .unwrap(); let local_and = { let local_and = dfg .define_function( @@ -189,6 +193,25 @@ fn test_dataflow_ports_only() { ) .unwrap(); dfg.add_other_wire(not.node(), call.node()).unwrap(); + + // As temporary workaround for https://github.com/CQCL/hugr/issues/695 + // We force the input-extensions of the FuncDefn node to include the logic + // extension, so the static edge from the FuncDefn to the call has the same + // extensions as the result of the "not". + { + let nt = dfg.hugr_mut().op_types.get_mut(local_and.node().pg_index()); + assert_eq!(nt.input_extensions, Some(ExtensionSet::new())); + nt.input_extensions = Some(ExtensionSet::singleton(&EXTENSION_ID)); + } + // Note that presently the builder sets too many input-exts that could be + // left to the inference (https://github.com/CQCL/hugr/issues/702) hence we + // must manually change these too, although we can let inference deal with them + for node in dfg.hugr().get_io(local_and.node()).unwrap() { + let nt = dfg.hugr_mut().op_types.get_mut(node.pg_index()); + assert_eq!(nt.input_extensions, Some(ExtensionSet::new())); + nt.input_extensions = None; + } + let h = dfg .finish_hugr_with_outputs(not.outputs(), &PRELUDE_REGISTRY) .unwrap(); From 89cea89268b61e45c8826fd6db335aa1eb659936 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sat, 2 Dec 2023 00:57:21 +0000 Subject: [PATCH 22/27] sibling_subgraph: add lift nodes to test, pass extension delta to construct_simple_replacement. Ugh... --- src/hugr/views/sibling_subgraph.rs | 75 ++++++++++++++++++++++-------- 1 file changed, 56 insertions(+), 19 deletions(-) diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index 461083378..116b429bf 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -287,8 +287,9 @@ impl SiblingSubgraph { &self.outputs } - /// The signature of the subgraph. + /// The signature of the subgraph, excluding any extension delta pub fn signature(&self, hugr: &impl HugrView) -> FunctionType { + // We cannot calculate the delta from just the extensions at the input and output! let input = self .inputs .iter() @@ -344,7 +345,12 @@ impl SiblingSubgraph { let Some([rep_input, rep_output]) = replacement.get_io(rep_root) else { return Err(InvalidReplacement::InvalidDataflowParent); }; - if dfg_optype.dataflow_signature() != Some(self.signature(hugr)) { + if !dfg_optype.dataflow_signature().is_some_and(|rep_sig| { + rep_sig + == self + .signature(hugr) + .with_extension_delta(&rep_sig.extension_reqs) + }) { return Err(InvalidReplacement::InvalidSignature); } @@ -408,8 +414,15 @@ impl SiblingSubgraph { &self, hugr: &impl HugrView, name: impl Into, + extension_delta: &crate::extension::ExtensionSet, ) -> Result { - let mut builder = FunctionBuilder::new(name, self.signature(hugr).into()).unwrap(); + let mut builder = FunctionBuilder::new( + name, + self.signature(hugr) + .with_extension_delta(extension_delta) + .into(), + ) + .unwrap(); // Take the unfinished Hugr from the builder, to avoid unnecessary // validation checks that require connecting the inputs and outputs. let mut extracted = mem::take(builder.hugr_mut()); @@ -675,8 +688,9 @@ mod tests { use cool_asserts::assert_matches; - use crate::extension::PRELUDE_REGISTRY; - use crate::utils::test_quantum_extension::cx_gate; + use crate::extension::{ExtensionSet, PRELUDE_REGISTRY}; + use crate::ops::LeafOp; + use crate::utils::test_quantum_extension::{cx_gate, EXTENSION_ID}; use crate::{ builder::{ BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, @@ -723,17 +737,26 @@ mod tests { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", - FunctionType::new_endo(type_row![QB_T, QB_T, QB_T]).into(), + FunctionType::new_endo(type_row![QB_T, QB_T, QB_T]) + .with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID)) + .into(), )?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; let [w0, w1, w2] = dfg.input_wires_arr(); let [w0, w1] = dfg.add_dataflow_op(cx_gate(), [w0, w1])?.outputs_arr(); + let [w2] = dfg + .add_dataflow_op( + LeafOp::Lift { + type_row: vec![QB_T].into(), + new_extension: EXTENSION_ID, + }, + [w2], + )? + .outputs_arr(); dfg.finish_with_outputs([w0, w1, w2])? }; - let hugr = mod_builder - .finish_prelude_hugr() - .map_err(|e| -> BuildError { e.into() })?; + let hugr = mod_builder.finish_prelude_hugr()?; Ok((hugr, func_id.node())) } @@ -797,18 +820,32 @@ mod tests { let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?; let empty_dfg = { - let builder = DFGBuilder::new(FunctionType::new_endo(type_row![QB_T, QB_T])).unwrap(); + let mut builder = DFGBuilder::new( + FunctionType::new_endo(type_row![QB_T, QB_T, QB_T]) + .with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID)), + ) + .unwrap(); let inputs = builder.input_wires(); - builder.finish_prelude_hugr_with_outputs(inputs).unwrap() + let lifted = builder + .add_dataflow_op( + LeafOp::Lift { + type_row: type_row![QB_T, QB_T, QB_T], + new_extension: EXTENSION_ID, + }, + inputs, + ) + .unwrap(); + builder.set_outputs(lifted.outputs()).unwrap(); + builder.finish_prelude_hugr().unwrap() }; let rep = sub.create_simple_replacement(&func, empty_dfg).unwrap(); - assert_eq!(rep.subgraph().nodes().len(), 1); + assert_eq!(rep.subgraph().nodes().len(), 2); - assert_eq!(hugr.node_count(), 5); // Module + Def + In + CX + Out + assert_eq!(hugr.node_count(), 6); // Module + Def + In + CX + Lift + Out hugr.apply_rewrite(rep).unwrap(); - assert_eq!(hugr.node_count(), 4); // Module + Def + In + Out + assert_eq!(hugr.node_count(), 5); // Module + Def + In + Lift + Out Ok(()) } @@ -818,11 +855,10 @@ mod tests { let (hugr, dfg) = build_hugr().unwrap(); let func: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, dfg).unwrap(); let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?; - // The identity wire on the third qubit is ignored, so the subgraph's signature only contains - // the first two qubits. + // The third wire is included because of the "Lift" node assert_eq!( sub.signature(&func), - FunctionType::new_endo(type_row![QB_T, QB_T]) + FunctionType::new_endo(type_row![QB_T, QB_T, QB_T]) ); Ok(()) } @@ -855,7 +891,7 @@ mod tests { .unwrap() .nodes() .len(), - 1 + 2 // Include Lift node ) } @@ -960,7 +996,8 @@ mod tests { let func_graph: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap(); - let extracted = subgraph.extract_subgraph(&hugr, "region")?; + let extracted = + subgraph.extract_subgraph(&hugr, "region", &ExtensionSet::singleton(&EXTENSION_ID))?; extracted.validate(&PRELUDE_REGISTRY).unwrap(); From 4c4bb874f3dd26b89578594321b75352891cc354 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 12 Dec 2023 12:58:51 +0000 Subject: [PATCH 23/27] Add ExtensionSet::union_over --- src/extension.rs | 10 ++++++++++ src/extension/infer.rs | 2 +- src/std_extensions/collections.rs | 6 ++---- src/values.rs | 5 +---- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/extension.rs b/src/extension.rs index f18834baa..95b0474ea 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -387,6 +387,16 @@ impl ExtensionSet { self } + /// Returns the union of an arbitrary collection of [ExtensionSet]s + pub fn union_over(sets: impl IntoIterator) -> Self { + // `union` clones the receiver, which we do not need to do here + let mut res = ExtensionSet::new(); + for s in sets { + res.0.extend(s.0) + } + res + } + /// The things in other which are in not in self pub fn missing_from(&self, other: &Self) -> Self { ExtensionSet::from_iter(other.0.difference(&self.0).cloned()) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 5d8176968..86a082c81 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -699,7 +699,7 @@ impl UnificationContext { }); let (rs, other_ms): (Vec<_>, Vec<_>) = plus_constraints.unzip(); - let solution = rs.iter().fold(ExtensionSet::new(), ExtensionSet::union); + let solution = ExtensionSet::union_over(rs); let unresolved_metas = other_ms .into_iter() .filter(|other_m| m != *other_m) diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index fa42b1a78..a78f4793a 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -68,10 +68,8 @@ impl CustomConst for ListValue { } fn extension_reqs(&self) -> ExtensionSet { - self.0 - .iter() - .map(Value::extension_reqs) - .fold(ExtensionSet::singleton(&EXTENSION_NAME), |a, b| a.union(&b)) + ExtensionSet::union_over(self.0.iter().map(Value::extension_reqs)) + .union(&ExtensionSet::singleton(&EXTENSION_NAME)) } } const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; diff --git a/src/values.rs b/src/values.rs index f6215760a..17d173a00 100644 --- a/src/values.rs +++ b/src/values.rs @@ -122,10 +122,7 @@ impl Value { match self { Value::Extension { c } => c.0.extension_reqs().clone(), Value::Function { .. } => ExtensionSet::new(), // no extensions reqd to load Hugr (only to run) - Value::Tuple { vs } => vs - .iter() - .map(Value::extension_reqs) - .fold(ExtensionSet::new(), |a, b| a.union(&b)), + Value::Tuple { vs } => ExtensionSet::union_over(vs.iter().map(Value::extension_reqs)), Value::Sum { value, .. } => value.extension_reqs(), } } From b8de62f36d667e5afd8ad0d96e273bc99d837430 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 12 Dec 2023 12:59:32 +0000 Subject: [PATCH 24/27] driveby turn lambda into ExtensionSet::union --- src/extension/infer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 86a082c81..84e22b65a 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -727,7 +727,7 @@ impl UnificationContext { Constraint::Plus(_, other_m) => solutions.get(&self.resolve(*other_m)), Constraint::Equal(_) => None, }) - .fold(ExtensionSet::new(), |a, b| a.union(b)); + .fold(ExtensionSet::new(), ExtensionSet::union); for m in cc.iter() { self.add_solution(*m, combined_solution.clone()); From 66bff433b86dda17b2cbe0e9f38cc217ff059bc2 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 12 Dec 2023 14:31:17 +0000 Subject: [PATCH 25/27] Add extra set in 'impl SignatureFunc' rather than 'impl OpDef' --- src/extension/op_def.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 5f8284e17..20aea61b5 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -245,7 +245,10 @@ impl SignatureFunc { } }; - let res = pf.instantiate(args, exts)?; + let mut res = pf.instantiate(args, exts)?; + res.extension_reqs = res + .extension_reqs + .union(&ExtensionSet::singleton(&def.extension())); Ok(res) } } @@ -343,11 +346,7 @@ impl OpDef { args: &[TypeArg], exts: &ExtensionRegistry, ) -> Result { - let mut functy = self.signature_func.compute_signature(self, args, exts)?; - functy.extension_reqs = functy - .extension_reqs - .union(&ExtensionSet::singleton(&self.extension)); - Ok(functy) + self.signature_func.compute_signature(self, args, exts) } /// Fallibly returns a Hugr that may replace an instance of this OpDef From fe8eaf106098b98626ec878b6b40f9880e54ac55 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 12 Dec 2023 14:38:48 +0000 Subject: [PATCH 26/27] a bit of clippy --- src/extension/op_def.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 20aea61b5..3e39f6523 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -248,7 +248,7 @@ impl SignatureFunc { let mut res = pf.instantiate(args, exts)?; res.extension_reqs = res .extension_reqs - .union(&ExtensionSet::singleton(&def.extension())); + .union(&ExtensionSet::singleton(def.extension())); Ok(res) } } From 22c44584e43f5b85682207c8c911ad256b25c263 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 12 Dec 2023 14:40:48 +0000 Subject: [PATCH 27/27] Fix cross-version unused-imports in views/tests.rs --- src/hugr/views/tests.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 145fe0bc7..43ccfbaee 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -3,7 +3,7 @@ use rstest::{fixture, rstest}; use crate::{ builder::{BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr}, - extension::{prelude::QB_T, ExtensionSet}, + extension::prelude::QB_T, ops::handle::{DataflowOpID, NodeHandle}, type_row, types::FunctionType, @@ -161,7 +161,7 @@ fn static_targets() { #[test] fn test_dataflow_ports_only() { use crate::builder::DataflowSubContainer; - use crate::extension::{prelude::BOOL_T, PRELUDE_REGISTRY}; + use crate::extension::{prelude::BOOL_T, ExtensionSet, PRELUDE_REGISTRY}; use crate::hugr::views::PortIterator; use crate::std_extensions::logic::{NotOp, EXTENSION_ID}; use itertools::Itertools;