Skip to content

Commit

Permalink
feat: implement RemoveConst and RemoveConstIgnore (#757)
Browse files Browse the repository at this point in the history
as per spec

refactor!: allow Into<Const> for builder.add_const

BREAKING_CHANGES: existing CustomConst.into() calls will error

---------

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Alan Lawrence <[email protected]>
Co-authored-by: Luca Mondada <[email protected]>
Co-authored-by: Luca Mondada <[email protected]>
  • Loading branch information
5 people authored Jan 3, 2024
1 parent 7f749e8 commit d5e7d63
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ pub trait Container {
///
/// This function will return an error if there is an error in adding the
/// [`OpType::Const`] node.
fn add_constant(&mut self, constant: ops::Const) -> Result<ConstID, BuildError> {
let const_n = self.add_child_node(NodeType::new(constant, ExtensionSet::new()))?;
fn add_constant(&mut self, constant: impl Into<ops::Const>) -> Result<ConstID, BuildError> {
let const_n = self.add_child_node(NodeType::new(constant.into(), ExtensionSet::new()))?;

Ok(const_n.into())
}
Expand Down Expand Up @@ -374,7 +374,7 @@ pub trait Dataflow: Container {
/// # Errors
///
/// This function will return an error if there is an error when adding the node.
fn add_load_const(&mut self, constant: ops::Const) -> Result<Wire, BuildError> {
fn add_load_const(&mut self, constant: impl Into<ops::Const>) -> Result<Wire, BuildError> {
let cid = self.add_constant(constant)?;
self.load_const(&cid)
}
Expand Down
4 changes: 2 additions & 2 deletions src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ mod test {
let build_result: Result<Hugr, ValidationError> = {
let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T])?;
let [i1] = loop_b.input_wires_arr();
let const_wire = loop_b.add_load_const(ConstUsize::new(1).into())?;
let const_wire = loop_b.add_load_const(ConstUsize::new(1))?;

let break_wire = loop_b.make_break(loop_b.loop_signature()?.clone(), [const_wire])?;
loop_b.set_outputs(break_wire, [i1])?;
Expand Down Expand Up @@ -173,7 +173,7 @@ mod test {
let mut branch_1 = conditional_b.case_builder(1)?;
let [_b1] = branch_1.input_wires_arr();

let wire = branch_1.add_load_const(ConstUsize::new(2).into())?;
let wire = branch_1.add_load_const(ConstUsize::new(2))?;
let break_wire = branch_1.make_break(signature, [wire])?;
branch_1.finish_with_outputs([break_wire])?;

Expand Down
1 change: 1 addition & 0 deletions src/hugr/rewrite.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Rewrite operations on the HUGR - replacement, outlining, etc.
pub mod consts;
pub mod insert_identity;
pub mod outline_cfg;
pub mod replace;
Expand Down
214 changes: 214 additions & 0 deletions src/hugr/rewrite/consts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
//! Rewrite operations involving Const and LoadConst operations
use std::iter;

use crate::{
hugr::{HugrError, HugrMut},
HugrView, Node,
};

use itertools::Itertools;
use thiserror::Error;

use super::Rewrite;

/// Remove a [`crate::ops::LoadConstant`] node with no consumers.
#[derive(Debug, Clone)]
pub struct RemoveConstIgnore(pub Node);

/// Error from an [`RemoveConst`] or [`RemoveConstIgnore`] operation.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum RemoveError {
/// Invalid node.
#[error("Node is invalid (either not in HUGR or not correct operation).")]
InvalidNode(Node),
/// Node in use.
#[error("Node: {0:?} has non-zero outgoing connections.")]
ValueUsed(Node),
/// Removal error
#[error("Removing node caused error: {0:?}.")]
RemoveFail(#[from] HugrError),
}

impl Rewrite for RemoveConstIgnore {
type Error = RemoveError;

// The Const node the LoadConstant was connected to.
type ApplyResult = Node;

type InvalidationSet<'a> = iter::Once<Node>;

const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> {
let node = self.0;

if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) {
return Err(RemoveError::InvalidNode(node));
}

if h.out_value_types(node)
.next()
.is_some_and(|(p, _)| h.linked_inputs(node, p).next().is_some())
{
return Err(RemoveError::ValueUsed(node));
}

Ok(())
}

fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
self.verify(h)?;
let node = self.0;
let source = h
.input_neighbours(node)
.exactly_one()
.ok()
.expect("Validation should check a Const is connected to LoadConstant.");
h.remove_node(node)?;

Ok(source)
}

fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
iter::once(self.0)
}
}

/// Remove a [`crate::ops::Const`] node with no outputs.
#[derive(Debug, Clone)]
pub struct RemoveConst(pub Node);

impl Rewrite for RemoveConst {
type Error = RemoveError;

// The parent of the Const node.
type ApplyResult = Node;

type InvalidationSet<'a> = iter::Once<Node>;

const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> {
let node = self.0;

if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) {
return Err(RemoveError::InvalidNode(node));
}

if h.output_neighbours(node).next().is_some() {
return Err(RemoveError::ValueUsed(node));
}

Ok(())
}

fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
self.verify(h)?;
let node = self.0;
let parent = h
.get_parent(node)
.expect("Const node without a parent shouldn't happen.");
h.remove_node(node)?;

Ok(parent)
}

fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
iter::once(self.0)
}
}

#[cfg(test)]
mod test {
use super::*;
use crate::{
builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer},
extension::{
prelude::{ConstUsize, USIZE_T},
PRELUDE_REGISTRY,
},
hugr::HugrMut,
ops::{handle::NodeHandle, LeafOp},
type_row,
types::FunctionType,
};
#[test]
fn test_const_remove() -> Result<(), Box<dyn std::error::Error>> {
let mut build = ModuleBuilder::new();
let con_node = build.add_constant(ConstUsize::new(2))?;

let mut dfg_build =
build.define_function("main", FunctionType::new_endo(type_row![]).into())?;
let load_1 = dfg_build.load_const(&con_node)?;
let load_2 = dfg_build.load_const(&con_node)?;
let tup = dfg_build.add_dataflow_op(
LeafOp::MakeTuple {
tys: type_row![USIZE_T, USIZE_T],
},
[load_1, load_2],
)?;
dfg_build.finish_sub_container()?;

let mut h = build.finish_prelude_hugr()?;
// nodes are Module, Function, Input, Output, Const, LoadConstant*2, MakeTuple
assert_eq!(h.node_count(), 8);
let tup_node = tup.node();
// can't remove invalid node
assert_eq!(
h.apply_rewrite(RemoveConst(tup_node)),
Err(RemoveError::InvalidNode(tup_node))
);

assert_eq!(
h.apply_rewrite(RemoveConstIgnore(tup_node)),
Err(RemoveError::InvalidNode(tup_node))
);
let load_1_node = load_1.node();
let load_2_node = load_2.node();
let con_node = con_node.node();

let remove_1 = RemoveConstIgnore(load_1_node);
assert_eq!(
remove_1.invalidation_set().exactly_one().ok(),
Some(load_1_node)
);

let remove_2 = RemoveConstIgnore(load_2_node);

let remove_con = RemoveConst(con_node);
assert_eq!(
remove_con.invalidation_set().exactly_one().ok(),
Some(con_node)
);

// can't remove nodes in use
assert_eq!(
h.apply_rewrite(remove_1.clone()),
Err(RemoveError::ValueUsed(load_1_node))
);

// remove the use
h.remove_node(tup_node)?;

// remove first load
let reported_con_node = h.apply_rewrite(remove_1)?;
assert_eq!(reported_con_node, con_node);

// still can't remove const, in use by second load
assert_eq!(
h.apply_rewrite(remove_con.clone()),
Err(RemoveError::ValueUsed(con_node))
);

// remove second use
let reported_con_node = h.apply_rewrite(remove_2)?;
assert_eq!(reported_con_node, con_node);
// remove const
assert_eq!(h.apply_rewrite(remove_con)?, h.root());

assert_eq!(h.node_count(), 4);
assert!(h.validate(&PRELUDE_REGISTRY).is_ok());
Ok(())
}
}
2 changes: 1 addition & 1 deletion src/hugr/views/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ fn static_targets() {
)
.unwrap();

let c = dfg.add_constant(ConstUsize::new(1).into()).unwrap();
let c = dfg.add_constant(ConstUsize::new(1)).unwrap();

let load = dfg.load_const(&c).unwrap();

Expand Down

0 comments on commit d5e7d63

Please sign in to comment.