Skip to content

Commit

Permalink
refactor: Add From conversion from ExtensionId to ExtensionSet (#855)
Browse files Browse the repository at this point in the history
This case comes up too often to keep writing out
`ExtensionSet::singleton`
  • Loading branch information
croyzor authored Mar 5, 2024
1 parent 697e7d7 commit 6be3ca2
Show file tree
Hide file tree
Showing 17 changed files with 95 additions and 122 deletions.
17 changes: 7 additions & 10 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub trait Container {
/// This function will return an error if there is an error in adding the
/// [`OpType::Const`] node.
fn add_constant(&mut self, constant: impl Into<ops::Const>) -> Result<ConstID, BuildError> {
let const_n = self.add_child_node(NodeType::new(constant.into(), ExtensionSet::new()))?;
let const_n = self.add_child_node(NodeType::new_pure(constant.into()))?;

Ok(const_n.into())
}
Expand All @@ -89,13 +89,10 @@ pub trait Container {
signature: PolyFuncType,
) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
let body = signature.body().clone();
let f_node = self.add_child_node(NodeType::new(
ops::FuncDefn {
name: name.into(),
signature,
},
ExtensionSet::new(),
))?;
let f_node = self.add_child_node(NodeType::new_pure(ops::FuncDefn {
name: name.into(),
signature,
}))?;

let db =
DFGBuilder::create_with_io(self.hugr_mut(), f_node, body, Some(ExtensionSet::new()))?;
Expand Down Expand Up @@ -335,9 +332,9 @@ pub trait Dataflow: Container {
NodeType::new(
ops::CFG {
signature: FunctionType::new(inputs.clone(), output_types.clone())
.with_extension_delta(&extension_delta),
.with_extension_delta(extension_delta),
},
input_extensions,
input_extensions.into(),
),
input_wires,
)?;
Expand Down
4 changes: 2 additions & 2 deletions src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
let outputs = cond.outputs;
let case_op = ops::Case {
signature: FunctionType::new(inputs.clone(), outputs.clone())
.with_extension_delta(&extension_delta),
.with_extension_delta(extension_delta.clone()),
};
let case_node =
// add case before any existing subsequent cases
Expand All @@ -137,7 +137,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
let dfg_builder = DFGBuilder::create_with_io(
self.hugr_mut(),
case_node,
FunctionType::new(inputs, outputs).with_extension_delta(&extension_delta),
FunctionType::new(inputs, outputs).with_extension_delta(extension_delta),
None,
)?;

Expand Down
13 changes: 6 additions & 7 deletions src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,22 +428,21 @@ pub(crate) mod test {
fn lift_node() -> Result<(), BuildError> {
let xa: ExtensionId = "A".try_into().unwrap();
let xb: ExtensionId = "B".try_into().unwrap();
let xc = "C".try_into().unwrap();
let xc: ExtensionId = "C".try_into().unwrap();
let ab_extensions = ExtensionSet::from_iter([xa.clone(), xb.clone()]);
let c_extensions = ExtensionSet::singleton(&xc);
let abc_extensions = ab_extensions.clone().union(&c_extensions);
let abc_extensions = ab_extensions.clone().union(&xc.clone().into());

let parent_sig =
FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&abc_extensions);
FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(abc_extensions);
let mut parent = DFGBuilder::new(parent_sig)?;

let add_c_sig =
FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&c_extensions);
FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(xc.clone());

let [w] = parent.input_wires_arr();

let add_ab_sig =
FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&ab_extensions);
let add_ab_sig = FunctionType::new(type_row![BIT], type_row![BIT])
.with_extension_delta(ab_extensions.clone());

// A box which adds extensions A and B, via child Lift nodes
let mut add_ab = parent.dfg_builder(add_ab_sig, Some(ExtensionSet::new()), [w])?;
Expand Down
2 changes: 1 addition & 1 deletion src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ mod test {
let mut fbuild = module_builder.define_function(
"main",
FunctionType::new(type_row![BIT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&PRELUDE_ID))
.with_extension_delta(PRELUDE_ID)
.into(),
)?;
let _fdef = {
Expand Down
12 changes: 9 additions & 3 deletions src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,14 +288,14 @@ pub struct Extension {
impl Extension {
/// Creates a new extension with the given name.
pub fn new(name: ExtensionId) -> Self {
Self::new_with_reqs(name, Default::default())
Self::new_with_reqs(name, ExtensionSet::default())
}

/// Creates a new extension with the given name and requirements.
pub fn new_with_reqs(name: ExtensionId, extension_reqs: ExtensionSet) -> Self {
pub fn new_with_reqs(name: ExtensionId, extension_reqs: impl Into<ExtensionSet>) -> Self {
Self {
name,
extension_reqs,
extension_reqs: extension_reqs.into(),
types: Default::default(),
values: Default::default(),
operations: Default::default(),
Expand Down Expand Up @@ -502,6 +502,12 @@ impl ExtensionSet {
}
}

impl From<ExtensionId> for ExtensionSet {
fn from(id: ExtensionId) -> Self {
Self::singleton(&id)
}
}

fn as_typevar(e: &ExtensionId) -> Option<usize> {
// Type variables are represented as radix-10 numbers, which are illegal
// as standard ExtensionIds. Hence if an ExtensionId starts with a digit,
Expand Down
86 changes: 34 additions & 52 deletions src/extension/infer/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ const_extension_ids! {
// them.
fn from_graph() -> Result<(), Box<dyn Error>> {
let rs = ExtensionSet::from_iter([A, B, C]);
let main_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]).with_extension_delta(&rs);
let main_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]).with_extension_delta(rs);

let op = ops::DFG {
signature: main_sig,
Expand All @@ -57,17 +57,14 @@ fn from_graph() -> Result<(), Box<dyn Error>> {

assert_matches!(hugr.get_io(hugr.root()), Some(_));

let add_a_sig = FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A));
let add_a_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(A);

let add_b_sig = FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&B));
let add_b_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(B);

let add_ab_sig = FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::from_iter([A, B]));
.with_extension_delta(ExtensionSet::from_iter([A, B]));

let mult_c_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&C));
let mult_c_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]).with_extension_delta(C);

let add_a = hugr.add_node_with_parent(
hugr.root(),
Expand Down Expand Up @@ -128,16 +125,10 @@ fn plus() -> Result<(), InferExtensionError> {
})
.collect();

ctx.solved.insert(metas[2], ExtensionSet::singleton(&A));
ctx.solved.insert(metas[2], A.into());
ctx.add_constraint(metas[1], Constraint::Equal(metas[2]));
ctx.add_constraint(
metas[0],
Constraint::Plus(ExtensionSet::singleton(&B), metas[2]),
);
ctx.add_constraint(
metas[4],
Constraint::Plus(ExtensionSet::singleton(&C), metas[0]),
);
ctx.add_constraint(metas[0], Constraint::Plus(B.into(), metas[2]));
ctx.add_constraint(metas[4], Constraint::Plus(C.into(), metas[0]));
ctx.add_constraint(metas[3], Constraint::Equal(metas[4]));
ctx.add_constraint(metas[5], Constraint::Equal(metas[0]));
ctx.main_loop()?;
Expand All @@ -164,8 +155,7 @@ fn plus() -> Result<(), InferExtensionError> {
// because of a missing lift node
fn missing_lift_node() -> Result<(), Box<dyn Error>> {
let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A)),
signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(A),
}));

let input = hugr.add_node_with_parent(
Expand Down Expand Up @@ -211,8 +201,8 @@ fn open_variables() -> Result<(), InferExtensionError> {
.insert((NodeIndex::new(4).into(), Direction::Incoming), ab);
ctx.variables.insert(a);
ctx.variables.insert(b);
ctx.add_constraint(ab, Constraint::Plus(ExtensionSet::singleton(&A), b));
ctx.add_constraint(ab, Constraint::Plus(ExtensionSet::singleton(&B), a));
ctx.add_constraint(ab, Constraint::Plus(A.into(), b));
ctx.add_constraint(ab, Constraint::Plus(B.into(), a));
let solution = ctx.main_loop()?;
// We'll only find concrete solutions for the Incoming extension reqs of
// the main node created by `Hugr::default`
Expand All @@ -227,11 +217,12 @@ fn dangling_src() -> Result<(), Box<dyn Error>> {
let rs = ExtensionSet::singleton(&"R".try_into().unwrap());

let mut hugr = closed_dfg_root_hugr(
FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs),
FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(rs.clone()),
);

let [input, output] = hugr.get_io(hugr.root()).unwrap();
let add_r_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs);
let add_r_sig =
FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(rs.clone());

let add_r = hugr.add_node_with_parent(
hugr.root(),
Expand All @@ -241,8 +232,7 @@ fn dangling_src() -> Result<(), Box<dyn Error>> {
)?;

// Dangling thingy
let src_sig =
FunctionType::new(type_row![], type_row![NAT]).with_extension_delta(&ExtensionSet::new());
let src_sig = FunctionType::new(type_row![], type_row![NAT]);

let src = hugr.add_node_with_parent(hugr.root(), ops::DFG { signature: src_sig })?;

Expand Down Expand Up @@ -365,7 +355,7 @@ fn test_conditional_inference() -> Result<(), Box<dyn Error>> {
let conditional_node = hugr.root();

let case_op = ops::Case {
signature: FunctionType::new(inputs, outputs).with_extension_delta(&rs),
signature: FunctionType::new(inputs, outputs).with_extension_delta(rs),
};
let case0_node = build_case(&mut hugr, conditional_node, case_op.clone(), A, B)?;

Expand Down Expand Up @@ -393,7 +383,7 @@ fn extension_adding_sequence() -> Result<(), Box<dyn Error>> {
let mut hugr = Hugr::new(NodeType::new_open(ops::DFG {
signature: df_sig
.clone()
.with_extension_delta(&ExtensionSet::from_iter([A, B])),
.with_extension_delta(ExtensionSet::from_iter([A, B])),
}));

let root = hugr.root();
Expand All @@ -414,9 +404,7 @@ fn extension_adding_sequence() -> Result<(), Box<dyn Error>> {
let df_nodes: Vec<Node> = vec![A, A, B, B, A, B]
.into_iter()
.map(|ext| {
let dfg_sig = df_sig
.clone()
.with_extension_delta(&ExtensionSet::singleton(&ext));
let dfg_sig = df_sig.clone().with_extension_delta(ext.clone());
let [node, input, output] = create_with_io(
&mut hugr,
root,
Expand Down Expand Up @@ -468,7 +456,7 @@ fn make_block(
let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect();
let tuple_sum_type = Type::new_tuple_sum(tuple_sum_rows.clone());
let dfb_sig = FunctionType::new(inputs.clone(), vec![tuple_sum_type])
.with_extension_delta(&extension_delta.clone());
.with_extension_delta(extension_delta.clone());
let dfb = ops::DataflowBlock {
inputs,
other_outputs: type_row![],
Expand Down Expand Up @@ -554,14 +542,11 @@ fn create_entry_exit(
/// +-------------------------+
#[test]
fn infer_cfg_test() -> Result<(), Box<dyn Error>> {
let a = ExtensionSet::singleton(&A);
let abc = ExtensionSet::from_iter([A, B, C]);
let bc = ExtensionSet::from_iter([B, C]);
let b = ExtensionSet::singleton(&B);
let c = ExtensionSet::singleton(&C);

let mut hugr = Hugr::new(NodeType::new_open(ops::CFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&abc),
signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(abc),
}));

let root = hugr.root();
Expand All @@ -571,15 +556,15 @@ fn infer_cfg_test() -> Result<(), Box<dyn Error>> {
root,
type_row![NAT],
vec![type_row![NAT], type_row![NAT]],
a.clone(),
A.into(),
type_row![NAT],
)?;

let mkpred = hugr.add_node_with_parent(
entry,
make_opaque(
A,
FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&a),
FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(A),
),
)?;

Expand All @@ -600,23 +585,23 @@ fn infer_cfg_test() -> Result<(), Box<dyn Error>> {
root,
type_row![NAT],
vec![type_row![NAT], type_row![NAT]],
b.clone(),
B.into(),
)?;

let bb10 = make_block(
&mut hugr,
root,
type_row![NAT],
vec![type_row![NAT]],
c.clone(),
C.into(),
)?;

let bb11 = make_block(
&mut hugr,
root,
type_row![NAT],
vec![type_row![NAT]],
c.clone(),
C.into(),
)?;

// CFG Wiring
Expand Down Expand Up @@ -743,7 +728,7 @@ fn make_looping_cfg(

let mut hugr = Hugr::new(NodeType::new_open(ops::CFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&hugr_delta),
.with_extension_delta(hugr_delta),
}));

let root = hugr.root();
Expand All @@ -761,7 +746,7 @@ fn make_looping_cfg(
entry,
make_opaque(
UNKNOWN_EXTENSION,
FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&entry_ext),
FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(entry_ext),
),
)?;

Expand Down Expand Up @@ -818,10 +803,9 @@ fn simple_cfg_loop() -> Result<(), Box<dyn Error>> {

let mut hugr = Hugr::new(NodeType::new(
ops::CFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&just_a),
signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(A),
},
just_a.clone(),
Some(A.into()),
));

let root = hugr.root();
Expand Down Expand Up @@ -865,8 +849,7 @@ fn simple_cfg_loop() -> Result<(), Box<dyn Error>> {
#[test]
fn plus_on_self() -> Result<(), Box<dyn std::error::Error>> {
let ext = ExtensionId::new("unknown1").unwrap();
let delta = ExtensionSet::singleton(&ext);
let ft = FunctionType::new_endo(type_row![QB_T, QB_T]).with_extension_delta(&delta);
let ft = FunctionType::new_endo(type_row![QB_T, QB_T]).with_extension_delta(ext.clone());
let mut dfg = DFGBuilder::new(ft.clone())?;

// While https://github.com/CQCL/hugr/issues/388 is unsolved,
Expand All @@ -880,8 +863,7 @@ fn plus_on_self() -> Result<(), Box<dyn std::error::Error>> {
ft,
))
.into();
let unary_sig = FunctionType::new_endo(type_row![QB_T])
.with_extension_delta(&ExtensionSet::singleton(&ext));
let unary_sig = FunctionType::new_endo(type_row![QB_T]).with_extension_delta(ext.clone());
let unop: LeafOp = ExternalOp::Opaque(OpaqueOp::new(
ext,
"1qb_op",
Expand Down Expand Up @@ -957,7 +939,7 @@ fn simple_funcdefn() -> Result<(), Box<dyn Error>> {
let mut func_builder = builder.define_function(
"F",
FunctionType::new(vec![NAT], vec![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A))
.with_extension_delta(A)
.into(),
)?;

Expand All @@ -982,7 +964,7 @@ fn funcdefn_signature_mismatch() -> Result<(), Box<dyn Error>> {
let mut func_builder = builder.define_function(
"F",
FunctionType::new(vec![NAT], vec![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A))
.with_extension_delta(A)
.into(),
)?;

Expand Down Expand Up @@ -1017,7 +999,7 @@ fn funcdefn_signature_mismatch2() -> Result<(), Box<dyn Error>> {
let func_builder = builder.define_function(
"F",
FunctionType::new(vec![NAT], vec![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A))
.with_extension_delta(A)
.into(),
)?;

Expand Down
Loading

0 comments on commit 6be3ca2

Please sign in to comment.