Skip to content

Commit

Permalink
Comments and cleanup - thanks @ezrosent
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt committed Dec 11, 2023
1 parent ab61dfa commit 64d98df
Showing 1 changed file with 30 additions and 20 deletions.
50 changes: 30 additions & 20 deletions src/extract/global_greedy_dag.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::iter;

use rpds::{HashTrieMap, HashTrieSet};

use super::*;
Expand All @@ -15,19 +17,19 @@ type Reachable = HashTrieSet<ClassId>;
struct TermInfo {
node: NodeId,
eclass: ClassId,
node_cost: NotNan<f64>,
total_cost: NotNan<f64>,
node_cost: Cost,
total_cost: Cost,
// store the set of reachable terms from this term
reachable: Reachable,
size: usize,
}

// A TermDag needs to store terms that share common
// subterms using a hashmap.
// However, it also critically needs to be able to answer
// reachability queries in this dag `reachable`.
// This prevents double-counting costs when
// computing the cost of a term.
/// A TermDag needs to store terms that share common
/// subterms using a hashmap.
/// However, it also critically needs to be able to answer
/// reachability queries in this dag `reachable`.
/// This prevents double-counting costs when
/// computing the cost of a term.
#[derive(Default)]
pub struct TermDag {
nodes: Vec<Term>,
Expand All @@ -36,16 +38,16 @@ pub struct TermDag {
}

impl TermDag {
// Makes a new term using a node and children terms
// Correctly computes total_cost with sharing
// If this term contains itself, returns None
// If this term costs more than target, returns None
/// Makes a new term using a node and children terms
/// Correctly computes total_cost with sharing
/// If this term contains itself, returns None
/// If this term costs more than target, returns None
pub fn make(
&mut self,
node_id: NodeId,
node: &Node,
children: Vec<TermId>,
target: NotNan<f64>,
target: Cost,
) -> Option<TermId> {
let term = Term {
op: node.op.clone(),
Expand All @@ -66,13 +68,15 @@ impl TermDag {
eclass: node.eclass.clone(),
node_cost,
total_cost: node_cost,
reachable: [node.eclass.clone()].into_iter().collect(),
reachable: iter::once(node.eclass.clone()).collect(),
size: 1,
});
self.hash_cons.insert(term, next_id);
Some(next_id)
} else {
// check if children contains this node
// Check whether any child can already reach this node's eclass; if so, adding this node would create a cycle.
// This is sound because `reachable` is the set of reachable eclasses
// from this term.
for child in &children {
if self.info[*child].reachable.contains(&node.eclass) {
return None;
Expand Down Expand Up @@ -115,10 +119,16 @@ impl TermDag {
}
}

// Return a new term, like this one but making use of shared terms.
// Also return the cost of the new nodes.
fn get_cost(&self, shared: &mut Box<Reachable>, id: TermId) -> NotNan<f64> {
/// Return the cost that the term `id` adds on top of the eclasses already in `shared`.
/// A term whose eclass is already in `shared` contributes zero cost.
fn get_cost(&self, shared: &mut Box<Reachable>, id: TermId) -> Cost {
let eclass = self.info[id].eclass.clone();

// This is the key to why this algorithm is faster than greedy_dag.
// While doing the set union between reachable sets, we can stop early
// if we find a shared term.
// Since the term with `id` is shared, the reachable set of `id` will already
// be in `shared`.
if shared.contains(&eclass) {
NotNan::<f64>::new(0.0).unwrap()
} else {
Expand All @@ -132,11 +142,11 @@ impl TermDag {
}
}

pub fn node_cost(&self, id: TermId) -> NotNan<f64> {
pub fn node_cost(&self, id: TermId) -> Cost {
self.info[id].node_cost
}

pub fn total_cost(&self, id: TermId) -> NotNan<f64> {
pub fn total_cost(&self, id: TermId) -> Cost {
self.info[id].total_cost
}
}
Expand Down

0 comments on commit 64d98df

Please sign in to comment.