Skip to content

Commit

Permalink
add docstrings and simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Dec 22, 2023
1 parent 0e0411f commit 8e88f3e
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,13 @@ pub fn find_consts<'a, 'r: 'a>(
candidate_nodes: impl IntoIterator<Item = Node> + 'a,
reg: &'r ExtensionRegistry,
) -> impl Iterator<Item = (SimpleReplacement, Vec<RemoveConstIgnore>)> + 'a {
// track nodes for operations that have already been considered for folding
let mut used_neighbours = BTreeSet::new();

candidate_nodes
.into_iter()
.filter_map(move |n| {
// only look at LoadConstant
hugr.get_optype(n).is_load_constant().then_some(())?;

let (out_p, _) = hugr.out_value_types(n).exactly_one().ok()?;
Expand All @@ -118,6 +120,7 @@ pub fn find_consts<'a, 'r: 'a>(
.filter(|(n, _)| used_neighbours.insert(*n))
.collect_vec();
if neighbours.is_empty() {
// no uses of LoadConstant that haven't already been considered.
return None;
}
let fold_iter = neighbours
Expand All @@ -128,16 +131,21 @@ pub fn find_consts<'a, 'r: 'a>(
.flatten()
}

/// Attempt to evaluate and generate rewrites for the operation at `op_node`
fn fold_op(
hugr: &impl HugrView,
op_node: Node,
reg: &ExtensionRegistry,
) -> Option<(SimpleReplacement, Vec<RemoveConstIgnore>)> {
let (in_consts, removals): (Vec<_>, Vec<_>) = hugr
.node_inputs(op_node)
.filter_map(|in_p| get_const(hugr, op_node, in_p))
.filter_map(|in_p| {
let (con_op, load_n) = get_const(hugr, op_node, in_p)?;
Some(((in_p, con_op), RemoveConstIgnore(load_n)))
})
.unzip();
let neighbour_op = hugr.get_optype(op_node);
// attempt to evaluate op
let folded = fold_const(neighbour_op, &in_consts)?;
let (op_outs, consts): (Vec<_>, Vec<_>) = folded.into_iter().unzip();
let nu_out = op_outs
Expand All @@ -152,7 +160,7 @@ fn fold_op(
.collect();
let replacement = const_graph(consts, reg);
let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr)
.expect("Load consts and operation should form valid subgraph.");
.expect("Operation should form valid subgraph.");

let simple_replace = SimpleReplacement::new(
sibling_graph,
Expand All @@ -164,11 +172,9 @@ fn fold_op(
Some((simple_replace, removals))
}

fn get_const(
hugr: &impl HugrView,
op_node: Node,
in_p: IncomingPort,
) -> Option<((IncomingPort, Const), RemoveConstIgnore)> {
/// If `op_node` is connected to a LoadConstant at `in_p`, return the constant
/// and the LoadConstant node
fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<(Const, Node)> {
let (load_n, _) = hugr.single_linked_output(op_node, in_p)?;
let load_op = hugr.get_optype(load_n).as_load_constant()?;
let const_node = hugr
Expand All @@ -180,7 +186,7 @@ fn get_const(
let const_op = hugr.get_optype(const_node).as_const()?;

// TODO avoid const clone here
Some(((in_p, const_op.clone()), RemoveConstIgnore(load_n)))
Some((const_op.clone(), load_n))
}

/// Exhaustively apply constant folding to a HUGR.
Expand Down

0 comments on commit 8e88f3e

Please sign in to comment.