Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Ops require their own extensions #734

Closed
wants to merge 30 commits into from
Closed
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f1e1e6f
Add Value::extension_reqs (not used yet)
acl-cqc Dec 1, 2023
cdc5785
Add OpTrait::extension_delta, non-empty for Const or DataflowOp
acl-cqc Dec 1, 2023
9e1b3eb
Fix test_conditional_inference
acl-cqc Dec 1, 2023
b47ad17
Fix test_tuple_sum by build_traits.rs: when adding load_const, do not…
acl-cqc Dec 1, 2023
bb66a8b
Fix static_targets
acl-cqc Dec 1, 2023
35e1787
add_constant drop ExtensionSet parameter - always empty
acl-cqc Dec 1, 2023
1da3a63
Fix replace::test::cfg (pending issue #388)
acl-cqc Dec 1, 2023
6c76919
clippy (cross-version issues)
acl-cqc Dec 2, 2023
25852a3
...and fmt
acl-cqc Dec 2, 2023
4856b17
Union OpDef's extension with that from SignatureFunc - in former not …
acl-cqc Dec 1, 2023
d07d186
Fix simple_linear (with lift)
acl-cqc Dec 1, 2023
dee0825
Fix nonlinear_and_outputs with another Lift
acl-cqc Dec 1, 2023
5ae9281
Fix nested_identity + copy_insertion (many Lift's + parameterize over…
acl-cqc Dec 1, 2023
265df79
fix op_def.rs tests
acl-cqc Dec 2, 2023
b87c884
Fix search_variable_deps to handle solved Meta, and fix replace test
acl-cqc Dec 2, 2023
97f9b87
Fix (simple_)replacement tests
acl-cqc Dec 2, 2023
9f51a9b
fix test_ext_edge
acl-cqc Dec 2, 2023
bc585f4
fix test_local_const
acl-cqc Dec 2, 2023
4721505
fix full_region/flat_region tests
acl-cqc Dec 2, 2023
91c1b22
Fix test_binary_signatures
acl-cqc Dec 2, 2023
3a281c4
Fix dataflow_ports_only
acl-cqc Dec 2, 2023
89cea89
sibling_subgraph: add lift nodes to test, pass extension delta to con…
acl-cqc Dec 2, 2023
6f0b215
Merge remote-tracking branch 'origin/main' into HEAD
acl-cqc Dec 12, 2023
4c4bb87
Add ExtensionSet::union_over
acl-cqc Dec 12, 2023
b8de62f
driveby turn lambda into ExtensionSet::union
acl-cqc Dec 12, 2023
b407523
Merge remote-tracking branch 'origin/main' into HEAD
acl-cqc Dec 12, 2023
7f59d6f
Merge commit 'b4075239' into fix/ops_require_ext
acl-cqc Dec 12, 2023
66bff43
Add extra set in 'impl SignatureFunc' rather than 'impl OpDef'
acl-cqc Dec 12, 2023
fe8eaf1
a bit of clippy
acl-cqc Dec 12, 2023
22c4458
Fix cross-version unused-imports in views/tests.rs
acl-cqc Dec 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -605,8 +605,8 @@ pub(crate) mod test {
// \-> right -/ \-<--<-/
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;

let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?,
Expand Down Expand Up @@ -887,8 +887,8 @@ pub(crate) mod test {
separate: bool,
) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> {
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 2, ExtensionSet::new())?,
Expand Down Expand Up @@ -929,8 +929,8 @@ pub(crate) mod test {
cfg_builder: &mut CFGBuilder<T>,
separate_headers: bool,
) -> Result<(BasicBlockID, BasicBlockID), BuildError> {
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?,
Expand Down
28 changes: 8 additions & 20 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,8 @@ pub trait Container {
///
/// This function will return an error if there is an error in adding the
/// [`OpType::Const`] node.
fn add_constant(
&mut self,
constant: ops::Const,
extensions: impl Into<Option<ExtensionSet>>,
) -> Result<ConstID, BuildError> {
let const_n = self.add_child_node(NodeType::new(constant, extensions.into()))?;
fn add_constant(&mut self, constant: ops::Const) -> Result<ConstID, BuildError> {
let const_n = self.add_child_node(NodeType::new(constant, ExtensionSet::new()))?;

Ok(const_n.into())
}
Expand Down Expand Up @@ -356,20 +352,16 @@ pub trait Dataflow: Container {
fn load_const(&mut self, cid: &ConstID) -> Result<Wire, BuildError> {
let const_node = cid.node();
let nodetype = self.hugr().get_nodetype(const_node);
let input_extensions = nodetype.input_extensions().cloned();
let op: ops::Const = nodetype
.op()
.clone()
.try_into()
.expect("ConstID does not refer to Const op.");

let load_n = self.add_dataflow_node(
NodeType::new(
ops::LoadConstant {
datatype: op.const_type().clone(),
},
input_extensions,
),
let load_n = self.add_dataflow_op(
ops::LoadConstant {
datatype: op.const_type().clone(),
},
// Constant wire from the constant value node
vec![Wire::new(const_node, OutgoingPort::from(0))],
)?;
Expand All @@ -382,12 +374,8 @@ pub trait Dataflow: Container {
/// # Errors
///
/// This function will return an error if there is an error when adding the node.
fn add_load_const(
&mut self,
constant: ops::Const,
extensions: ExtensionSet,
) -> Result<Wire, BuildError> {
let cid = self.add_constant(constant, extensions)?;
fn add_load_const(&mut self, constant: ops::Const) -> Result<Wire, BuildError> {
let cid = self.add_constant(constant)?;
self.load_const(&cid)
}

Expand Down
8 changes: 3 additions & 5 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ mod test {
let mut middle_b = cfg_builder
.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
let middle = {
let c = middle_b.add_load_const(ops::Const::unary_unit_sum(), ExtensionSet::new())?;
let c = middle_b.add_load_const(ops::Const::unary_unit_sum())?;
let [inw] = middle_b.input_wires_arr();
middle_b.finish_with_outputs(c, [inw])?
};
Expand All @@ -398,8 +398,7 @@ mod test {
#[test]
fn test_dom_edge() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let sum_tuple_const =
cfg_builder.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())?;
let sum_tuple_const = cfg_builder.add_constant(ops::Const::unary_unit_sum())?;
let sum_variants = vec![type_row![]];

let mut entry_b =
Expand Down Expand Up @@ -427,8 +426,7 @@ mod test {
#[test]
fn test_non_dom_edge() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let sum_tuple_const =
cfg_builder.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())?;
let sum_tuple_const = cfg_builder.add_constant(ops::Const::unary_unit_sum())?;
let sum_variants = vec![type_row![]];
let mut middle_b = cfg_builder
.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
Expand Down
35 changes: 28 additions & 7 deletions src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,33 @@ mod test {
test::{build_main, NAT, QB},
Dataflow, DataflowSubContainer, Wire,
},
extension::prelude::BOOL_T,
extension::{prelude::BOOL_T, ExtensionSet},
ops::{custom::OpaqueOp, LeafOp},
type_row,
types::FunctionType,
utils::test_quantum_extension::{cx_gate, h_gate, measure},
utils::test_quantum_extension::{cx_gate, h_gate, measure, EXTENSION_ID},
};

#[test]
fn simple_linear() {
let build_res = build_main(
FunctionType::new(type_row![QB, QB], type_row![QB, QB]).into(),
FunctionType::new_endo(type_row![QB, QB])
.with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID))
.into(),
|mut f_build| {
let wires = f_build.input_wires().collect();
let mut wires: [Wire; 2] = f_build.input_wires_arr();
[wires[1]] = f_build
.add_dataflow_op(
LeafOp::Lift {
type_row: vec![QB].into(),
new_extension: EXTENSION_ID,
},
[wires[1]],
)?
.outputs_arr();

let mut linear = CircuitBuilder {
wires,
wires: Vec::from(wires),
builder: &mut f_build,
};

Expand Down Expand Up @@ -184,10 +195,20 @@ mod test {
.into(),
);
let build_res = build_main(
FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]).into(),
FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T])
.with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID))
.into(),
|mut f_build| {
let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr();

let [angle] = f_build
.add_dataflow_op(
LeafOp::Lift {
type_row: vec![NAT].into(),
new_extension: EXTENSION_ID,
},
[angle],
)?
.outputs_arr();
let mut linear = f_build.as_circuit(vec![q0, q1]);

let measure_out = linear
Expand Down
2 changes: 1 addition & 1 deletion src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ mod test {
"main",
FunctionType::new(type_row![NAT], type_row![NAT]).into(),
)?;
let tru_const = fbuild.add_constant(Const::true_val(), ExtensionSet::new())?;
let tru_const = fbuild.add_constant(Const::true_val())?;
let _fdef = {
let const_wire = fbuild.load_const(&tru_const)?;
let [int] = fbuild.input_wires_arr();
Expand Down
48 changes: 42 additions & 6 deletions src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,9 @@ pub(crate) mod test {
use crate::hugr::validate::InterGraphEdgeError;
use crate::ops::{handle::NodeHandle, LeafOp, OpTag};

use crate::std_extensions::logic::test::and_op;
use crate::std_extensions::logic::{self, test::and_op};
use crate::types::Type;
use crate::utils::test_quantum_extension::h_gate;
use crate::utils::test_quantum_extension::{self, h_gate};
use crate::{
builder::{
test::{n_identity, BIT, NAT, QB},
Expand All @@ -235,13 +235,25 @@ pub(crate) mod test {
let _f_id = {
let mut func_builder = module_builder.define_function(
"main",
FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).into(),
FunctionType::new(type_row![NAT, QB], type_row![NAT, QB])
.with_extension_delta(&ExtensionSet::singleton(
&test_quantum_extension::EXTENSION_ID,
))
.into(),
)?;

let [int, qb] = func_builder.input_wires_arr();

let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb])?;

let [int] = func_builder
.add_dataflow_op(
LeafOp::Lift {
type_row: vec![NAT].into(),
new_extension: test_quantum_extension::EXTENSION_ID,
},
[int],
)?
.outputs_arr();
let inner_builder = func_builder.dfg_builder(
FunctionType::new(type_row![NAT], type_row![NAT]),
None,
Expand All @@ -260,7 +272,7 @@ pub(crate) mod test {
}

// Scaffolding for copy insertion tests
fn copy_scaffold<F>(f: F, msg: &'static str) -> Result<(), BuildError>
fn copy_scaffold<F>(f: F, delta: &ExtensionSet, msg: &'static str) -> Result<(), BuildError>
where
F: FnOnce(FunctionBuilder<&mut Hugr>) -> Result<BuildHandle<FuncID<true>>, BuildError>,
{
Expand All @@ -269,7 +281,9 @@ pub(crate) mod test {

let f_build = module_builder.define_function(
"main",
FunctionType::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]).into(),
FunctionType::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T])
.with_extension_delta(delta)
.into(),
)?;

f(f_build)?;
Expand All @@ -287,25 +301,47 @@ pub(crate) mod test {
let [b1] = f_build.input_wires_arr();
f_build.finish_with_outputs([b1, b1])
},
&ExtensionSet::new(),
"Copy input and output",
)?;

let es = ExtensionSet::singleton(&logic::EXTENSION_ID);
copy_scaffold(
|mut f_build| {
let [b1] = f_build.input_wires_arr();
let xor = f_build.add_dataflow_op(and_op(), [b1, b1])?;
let [b1] = f_build
.add_dataflow_op(
LeafOp::Lift {
type_row: vec![BOOL_T].into(),
new_extension: logic::EXTENSION_ID,
},
[b1],
)?
.outputs_arr();
f_build.finish_with_outputs([xor.out_wire(0), b1])
},
&es,
"Copy input and use with binary function",
)?;

copy_scaffold(
|mut f_build| {
let [b1] = f_build.input_wires_arr();
let xor1 = f_build.add_dataflow_op(and_op(), [b1, b1])?;
let [b1] = f_build
.add_dataflow_op(
LeafOp::Lift {
type_row: vec![BOOL_T].into(),
new_extension: logic::EXTENSION_ID,
},
[b1],
)?
.outputs_arr();
let xor2 = f_build.add_dataflow_op(and_op(), [b1, xor1.out_wire(0)])?;
f_build.finish_with_outputs([xor2.out_wire(0), b1])
},
&es,
"Copy multiple times",
)?;

Expand Down
13 changes: 3 additions & 10 deletions src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,7 @@ mod test {
let build_result: Result<Hugr, ValidationError> = {
let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T])?;
let [i1] = loop_b.input_wires_arr();
let const_wire = loop_b.add_load_const(
ConstUsize::new(1).into(),
ExtensionSet::singleton(&PRELUDE_ID),
)?;
let const_wire = loop_b.add_load_const(ConstUsize::new(1).into())?;

let break_wire = loop_b.make_break(loop_b.loop_signature()?.clone(), [const_wire])?;
loop_b.set_outputs(break_wire, [i1])?;
Expand Down Expand Up @@ -148,8 +145,7 @@ mod test {
fbuild.tail_loop_builder(vec![(BIT, b1)], vec![], type_row![NAT])?;
let signature = loop_b.loop_signature()?.clone();
let const_val = Const::true_val();
let const_wire =
loop_b.add_load_const(Const::true_val(), ExtensionSet::new())?;
let const_wire = loop_b.add_load_const(Const::true_val())?;
let lift_node = loop_b.add_dataflow_op(
ops::LeafOp::Lift {
type_row: vec![const_val.const_type().clone()].into(),
Expand Down Expand Up @@ -177,10 +173,7 @@ mod test {
let mut branch_1 = conditional_b.case_builder(1)?;
let [_b1] = branch_1.input_wires_arr();

let wire = branch_1.add_load_const(
ConstUsize::new(2).into(),
ExtensionSet::singleton(&PRELUDE_ID),
)?;
let wire = branch_1.add_load_const(ConstUsize::new(2).into())?;
let break_wire = branch_1.make_break(signature, [wire])?;
branch_1.finish_with_outputs([break_wire])?;

Expand Down
15 changes: 6 additions & 9 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,15 +317,11 @@ impl UnificationContext {
match node_type.io_extensions() {
// Input extensions are open
None => {
let c = if let Some(sig) = node_type.op_signature() {
let delta = sig.extension_reqs;
if delta.is_empty() {
Constraint::Equal(m_input)
} else {
Constraint::Plus(delta, m_input)
}
} else {
let delta = node_type.op().extension_delta();
let c = if delta.is_empty() {
Constraint::Equal(m_input)
} else {
Constraint::Plus(delta, m_input)
};
self.add_constraint(m_output, c);
}
Expand Down Expand Up @@ -652,11 +648,12 @@ impl UnificationContext {
fn search_variable_deps(&self) -> HashSet<Meta> {
let mut seen = HashSet::new();
let mut new_variables: HashSet<Meta> = self.variables.clone();
let constraints_for_solved = HashSet::new();
while !new_variables.is_empty() {
new_variables = new_variables
.into_iter()
.filter(|m| seen.insert(*m))
.flat_map(|m| self.get_constraints(&m).unwrap())
.flat_map(|m| self.get_constraints(&m).unwrap_or(&constraints_for_solved))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAICS solved constraints get removed here:

hugr/src/extension/infer.rs

Lines 633 to 635 in 6959c89

to_delete.iter().for_each(|m| {
self.constraints.remove(m);
});

so we have to handle them somehow. However I haven't looked into this all that hard so may have missed something @croyzor ?

.map(|c| match c {
Constraint::Plus(_, other) => self.resolve(*other),
Constraint::Equal(other) => self.resolve(*other),
Expand Down
Loading