From 41c9c144a411e9b8b5b058de5eae38b9a30e5247 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 24 Oct 2023 15:13:24 +0100 Subject: [PATCH] Change many add_node+open_extensions to use add_op --- src/builder.rs | 12 ++-- src/builder/conditional.rs | 2 +- src/extension/infer.rs | 140 +++++++++++++++++-------------------- src/hugr.rs | 8 +-- src/hugr/hugrmut.rs | 2 +- 5 files changed, 76 insertions(+), 88 deletions(-) diff --git a/src/builder.rs b/src/builder.rs index c8c0cdc55..51b90f8a2 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -148,18 +148,18 @@ pub(crate) mod test { let mut hugr = Hugr::new(NodeType::pure(ops::DFG { signature: signature.clone(), })); - hugr.add_node_with_parent( + hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::Input { + ops::Input { types: signature.input, - }), + }, ) .unwrap(); - hugr.add_node_with_parent( + hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::Output { + ops::Output { types: signature.output, - }), + }, ) .unwrap(); hugr diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index 7b196d3a9..da8808eea 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -128,7 +128,7 @@ impl + AsRef> ConditionalBuilder { if let Some(&sibling_node) = self.case_nodes[case + 1..].iter().flatten().next() { self.hugr_mut().add_op_before(sibling_node, case_op)? } else { - self.add_child_node(NodeType::open_extensions(case_op))? + self.add_child_op(case_op)? }; self.case_nodes[case] = Some(case_node); diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 236d3aeb4..88efb5cfe 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -730,11 +730,11 @@ mod test { let root_node = NodeType::open_extensions(op); let mut hugr = Hugr::new(root_node); - let input = NodeType::open_extensions(ops::Input::new(type_row![NAT, NAT])); - let output = NodeType::open_extensions(ops::Output::new(type_row![NAT])); + let input = ops::Input::new(type_row![NAT, NAT]); + let output = ops::Output::new(type_row![NAT]); - let input = hugr.add_node_with_parent(hugr.root(), input)?; - let output = hugr.add_node_with_parent(hugr.root(), output)?; + let input = hugr.add_op_with_parent(hugr.root(), input)?; + let output = hugr.add_op_with_parent(hugr.root(), output)?; assert_matches!(hugr.get_io(hugr.root()), Some(_)); @@ -750,29 +750,29 @@ mod test { let mult_c_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]) .with_extension_delta(&ExtensionSet::singleton(&C)); - let add_a = hugr.add_node_with_parent( + let add_a = hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::DFG { + ops::DFG { signature: add_a_sig, - }), + }, )?; - let add_b = hugr.add_node_with_parent( + let add_b = hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::DFG { + ops::DFG { signature: add_b_sig, - }), + }, )?; - let add_ab = hugr.add_node_with_parent( + let add_ab = hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::DFG { + ops::DFG { signature: add_ab_sig, - }), + }, )?; - let mult_c = hugr.add_node_with_parent( + let mult_c = hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::DFG { + ops::DFG { signature: mult_c_sig, - }), + }, )?; hugr.connect(input, 0, add_a, 0)?; @@ -906,29 +906,26 @@ mod test { let [input, output] = hugr.get_io(hugr.root()).unwrap(); let add_r_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs); - let add_r = hugr.add_node_with_parent( + let add_r = hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::DFG { + ops::DFG { signature: add_r_sig, - }), + }, )?; // Dangling thingy let src_sig = FunctionType::new(type_row![], type_row![NAT]) .with_extension_delta(&ExtensionSet::new()); - let src = hugr.add_node_with_parent( - hugr.root(), - NodeType::open_extensions(ops::DFG { signature: src_sig }), - )?; + let src = hugr.add_op_with_parent(hugr.root(), ops::DFG { signature: src_sig })?; let mult_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]); // Mult has open extension requirements, which we should solve to be "R" - let mult = hugr.add_node_with_parent( + let mult = hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::DFG { + ops::DFG { signature: mult_sig, - }), + }, )?; hugr.connect(input, 0, add_r, 0)?; @@ -988,18 +985,18 @@ mod test { ) -> Result<[Node; 3], Box> { let op: OpType = op.into(); - let node = hugr.add_node_with_parent(parent, NodeType::open_extensions(op))?; - let input = hugr.add_node_with_parent( + let node = hugr.add_op_with_parent(parent, op)?; + let input = hugr.add_op_with_parent( node, - NodeType::open_extensions(ops::Input { + ops::Input { types: op_sig.input, - }), + }, )?; - let output = hugr.add_node_with_parent( + let output = hugr.add_op_with_parent( node, - NodeType::open_extensions(ops::Output { + ops::Output { types: op_sig.output, - }), + }, )?; Ok([node, input, output]) } @@ -1020,20 +1017,20 @@ mod test { Into::::into(op).signature(), )?; - let lift1 = hugr.add_node_with_parent( + let lift1 = hugr.add_op_with_parent( case, - NodeType::open_extensions(ops::LeafOp::Lift { + ops::LeafOp::Lift { type_row: type_row![NAT], new_extension: first_ext, - }), + }, )?; - let lift2 = hugr.add_node_with_parent( + let lift2 = hugr.add_op_with_parent( case, - NodeType::open_extensions(ops::LeafOp::Lift { + ops::LeafOp::Lift { type_row: type_row![NAT], new_extension: second_ext, - }), + }, )?; hugr.connect(case_in, 0, lift1, 0)?; @@ -1098,17 +1095,17 @@ mod test { })); let root = hugr.root(); - let input = hugr.add_node_with_parent( + let input = hugr.add_op_with_parent( root, - NodeType::open_extensions(ops::Input { + ops::Input { types: type_row![NAT], - }), + }, )?; - let output = hugr.add_node_with_parent( + let output = hugr.add_op_with_parent( root, - NodeType::open_extensions(ops::Output { + ops::Output { types: type_row![NAT], - }), + }, )?; // Make identical dataflow nodes which add extension requirement "A" or "B" @@ -1129,12 +1126,12 @@ mod test { .unwrap(); let lift = hugr - .add_node_with_parent( + .add_op_with_parent( node, - NodeType::open_extensions(ops::LeafOp::Lift { + ops::LeafOp::Lift { type_row: type_row![NAT], new_extension: ext, - }), + }, ) .unwrap(); @@ -1181,7 +1178,7 @@ mod test { let [bb, bb_in, bb_out] = create_with_io(hugr, bb_parent, dfb, dfb_sig)?; - let dfg = hugr.add_node_with_parent(bb, NodeType::open_extensions(op))?; + let dfg = hugr.add_op_with_parent(bb, op)?; hugr.connect(bb_in, 0, dfg, 0)?; hugr.connect(dfg, 0, bb_out, 0)?; @@ -1213,23 +1210,20 @@ mod test { extension_delta: entry_extensions, }; - let exit = hugr.add_node_with_parent( + let exit = hugr.add_op_with_parent( root, - NodeType::open_extensions(ops::BasicBlock::Exit { + ops::BasicBlock::Exit { cfg_outputs: exit_types.into(), - }), + }, )?; - let entry = hugr.add_op_before(exit,dfb)?; - let entry_in = hugr.add_node_with_parent( - entry, - NodeType::open_extensions(ops::Input { types: inputs }), - )?; - let entry_out = hugr.add_node_with_parent( + let entry = hugr.add_op_before(exit, dfb)?; + let entry_in = hugr.add_op_with_parent(entry, ops::Input { types: inputs })?; + let entry_out = hugr.add_op_with_parent( entry, - NodeType::open_extensions(ops::Output { + ops::Output { types: vec![entry_tuple_sum].into(), - }), + }, )?; Ok(([entry, entry_in, entry_out], exit)) @@ -1280,12 +1274,12 @@ mod test { type_row![NAT], )?; - let mkpred = hugr.add_node_with_parent( + let mkpred = hugr.add_op_with_parent( entry, - NodeType::open_extensions(make_opaque( + make_opaque( A, FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&a), - )), + ), )?; // Internal wiring for DFGs @@ -1376,12 +1370,9 @@ mod test { type_row![NAT], )?; - let entry_mid = hugr.add_node_with_parent( + let entry_mid = hugr.add_op_with_parent( entry, - NodeType::open_extensions(make_opaque( - UNKNOWN_EXTENSION, - FunctionType::new(vec![NAT], twoway(NAT)), - )), + make_opaque(UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], twoway(NAT))), )?; hugr.connect(entry_in, 0, entry_mid, 0)?; @@ -1465,12 +1456,12 @@ mod test { type_row![NAT], )?; - let entry_dfg = hugr.add_node_with_parent( + let entry_dfg = hugr.add_op_with_parent( entry, - NodeType::open_extensions(make_opaque( + make_opaque( UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&entry_ext), - )), + ), )?; hugr.connect(entry_in, 0, entry_dfg, 0)?; @@ -1546,12 +1537,9 @@ mod test { type_row![NAT], )?; - let entry_mid = hugr.add_node_with_parent( + let entry_mid = hugr.add_op_with_parent( entry, - NodeType::open_extensions(make_opaque( - UNKNOWN_EXTENSION, - FunctionType::new(vec![NAT], oneway(NAT)), - )), + make_opaque(UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], oneway(NAT))), )?; hugr.connect(entry_in, 0, entry_mid, 0)?; diff --git a/src/hugr.rs b/src/hugr.rs index 322a123c6..903706e68 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -609,7 +609,7 @@ impl From for PyErr { #[cfg(test)] mod test { - use super::{Hugr, HugrView, NodeType}; + use super::{Hugr, HugrView}; use crate::builder::test::closed_dfg_root_hugr; use crate::extension::ExtensionSet; use crate::hugr::HugrMut; @@ -645,12 +645,12 @@ mod test { FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&r), ); let [input, output] = hugr.get_io(hugr.root()).unwrap(); - let lift = hugr.add_node_with_parent( + let lift = hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::LeafOp::Lift { + ops::LeafOp::Lift { type_row: type_row![BIT], new_extension: "R".try_into().unwrap(), - }), + }, )?; hugr.connect(input, 0, lift, 0)?; hugr.connect(lift, 0, output, 0)?; diff --git a/src/hugr/hugrmut.rs b/src/hugr/hugrmut.rs index 360011bb5..5a5d7f10e 100644 --- a/src/hugr/hugrmut.rs +++ b/src/hugr/hugrmut.rs @@ -206,7 +206,7 @@ impl + AsMut> HugrMut for T { } fn add_op_before(&mut self, sibling: Node, op: impl Into) -> Result { - let node = self.as_mut().add_node(NodeType::open_extensions(op)); + let node = self.as_mut().add_op(op); self.as_mut() .hierarchy .insert_before(node.index, sibling.index)?;