Skip to content

Commit

Permalink
fix: Sibling extension panics while computing signature with non-data…
Browse files Browse the repository at this point in the history
…flow nodes (#1350)

The signature computation for a `SiblingSubgraph` takes the union of the
nodes' extensions. This didn't contemplate non-dataflow nodes like
constants, and caused a runtime panic if one was present.

Most of the diff is adding a constant node in the tests. 

This is a fix for CQCL/tket2#507
  • Loading branch information
aborgna-q committed Jul 25, 2024
1 parent c21a999 commit 5fce36e
Showing 1 changed file with 49 additions and 26 deletions.
75 changes: 49 additions & 26 deletions hugr-core/src/hugr/views/sibling_subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,9 @@ impl SiblingSubgraph {
})
.collect_vec();
Signature::new(input, output).with_extension_delta(ExtensionSet::union_over(
self.nodes.iter().map(|n| {
hugr.signature(*n)
.expect("all nodes must have dataflow signature")
.extension_reqs
}),
self.nodes
.iter()
.map(|n| hugr.get_optype(*n).extension_delta()),
))
}

Expand Down Expand Up @@ -729,14 +727,14 @@ pub enum InvalidSubgraphBoundary {

#[cfg(test)]
mod tests {
use std::error::Error;

use cool_asserts::assert_matches;

use crate::builder::inout_sig;
use crate::extension::PRELUDE_REGISTRY;
use crate::extension::{prelude, ExtensionRegistry};
use crate::ops::Const;
use crate::std_extensions::arithmetic::float_types::{self, ConstF64};
use crate::std_extensions::logic;
use crate::utils::test_quantum_extension::{self, cx_gate};
use crate::utils::test_quantum_extension::{self, cx_gate, rz_f64};
use crate::{
builder::{
BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
Expand Down Expand Up @@ -778,22 +776,36 @@ mod tests {
}
}

/// A Module with a single function from three qubits to three qubits.
/// The function applies a CX gate to the first two qubits and a Rz gate (with a constant angle) to the last qubit.
fn build_hugr() -> Result<(Hugr, Node), BuildError> {
let mut mod_builder = ModuleBuilder::new();
let func = mod_builder.declare(
"test",
Signature::new_endo(type_row![QB_T, QB_T, QB_T])
.with_extension_delta(test_quantum_extension::EXTENSION_ID)
.with_extension_delta(ExtensionSet::from_iter([
test_quantum_extension::EXTENSION_ID,
float_types::EXTENSION_ID,
]))
.into(),
)?;
let func_id = {
let mut dfg = mod_builder.define_declaration(&func)?;
let [w0, w1, w2] = dfg.input_wires_arr();
let [w0, w1] = dfg.add_dataflow_op(cx_gate(), [w0, w1])?.outputs_arr();
let c = dfg.add_load_const(Const::new(ConstF64::new(0.5).into()));
let [w2] = dfg.add_dataflow_op(rz_f64(), [w2, c])?.outputs_arr();
dfg.finish_with_outputs([w0, w1, w2])?
};
let hugr = mod_builder
.finish_prelude_hugr()
.finish_hugr(
&ExtensionRegistry::try_new([
prelude::PRELUDE.to_owned(),
test_quantum_extension::EXTENSION.to_owned(),
float_types::EXTENSION.to_owned(),
])
.unwrap(),
)
.map_err(|e| -> BuildError { e.into() })?;
Ok((hugr, func_id.node()))
}
Expand Down Expand Up @@ -888,16 +900,17 @@ mod tests {
let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?;

let empty_dfg = {
let builder = DFGBuilder::new(Signature::new_endo(type_row![QB_T, QB_T])).unwrap();
let builder =
DFGBuilder::new(Signature::new_endo(type_row![QB_T, QB_T, QB_T])).unwrap();
let inputs = builder.input_wires();
builder.finish_prelude_hugr_with_outputs(inputs).unwrap()
};

let rep = sub.create_simple_replacement(&func, empty_dfg).unwrap();

assert_eq!(rep.subgraph().nodes().len(), 1);
assert_eq!(rep.subgraph().nodes().len(), 4);

assert_eq!(hugr.node_count(), 5); // Module + Def + In + CX + Out
assert_eq!(hugr.node_count(), 8); // Module + Def + In + CX + Rz + Const + LoadConst + Out
hugr.apply_rewrite(rep).unwrap();
assert_eq!(hugr.node_count(), 4); // Module + Def + In + Out

Expand All @@ -909,12 +922,14 @@ mod tests {
let (hugr, dfg) = build_hugr().unwrap();
let func: SiblingGraph<'_, FuncID<true>> = SiblingGraph::try_new(&hugr, dfg).unwrap();
let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?;
// The identity wire on the third qubit is ignored, so the subgraph's signature only contains
// the first two qubits.
assert_eq!(
sub.signature(&func),
Signature::new_endo(type_row![QB_T, QB_T])
.with_extension_delta(test_quantum_extension::EXTENSION_ID)
Signature::new_endo(type_row![QB_T, QB_T, QB_T]).with_extension_delta(
ExtensionSet::from_iter([
test_quantum_extension::EXTENSION_ID,
float_types::EXTENSION_ID,
])
)
);
Ok(())
}
Expand Down Expand Up @@ -947,7 +962,7 @@ mod tests {
.unwrap()
.nodes()
.len(),
1
4
)
}

Expand Down Expand Up @@ -1064,15 +1079,23 @@ mod tests {
}

#[test]
fn extract_subgraph() -> Result<(), Box<dyn Error>> {
let (hugr, func_root) = build_hugr()?;
let func_graph: SiblingGraph<'_, FuncID<true>> = SiblingGraph::try_new(&hugr, func_root)?;
let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph)?;
fn extract_subgraph() {
let (hugr, func_root) = build_hugr().unwrap();
let func_graph: SiblingGraph<'_, FuncID<true>> =
SiblingGraph::try_new(&hugr, func_root).unwrap();
let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap();
let extracted = subgraph.extract_subgraph(&hugr, "region");

extracted.validate(&PRELUDE_REGISTRY)?;

Ok(())
extracted
.validate(
&ExtensionRegistry::try_new([
prelude::PRELUDE.to_owned(),
test_quantum_extension::EXTENSION.to_owned(),
float_types::EXTENSION.to_owned(),
])
.unwrap(),
)
.unwrap();
}

#[test]
Expand Down

0 comments on commit 5fce36e

Please sign in to comment.