Skip to content

Commit

Permalink
Even faster greedy dag
Browse files Browse the repository at this point in the history
  • Loading branch information
TrevorHansen committed Sep 13, 2023
1 parent 728954f commit 5180a1f
Showing 1 changed file with 49 additions and 43 deletions.
92 changes: 49 additions & 43 deletions src/extract/greedy_dag_1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,62 +20,68 @@ impl FasterGreedyDagExtractor {
best_cost: Cost,
) -> CostSet {
let node = &egraph[&node_id];
let cid = egraph.nid_to_cid(&node_id);

let mut desc = 0;
let mut children_cost = Cost::default();
for child in &node.children {
let child_cid = egraph.nid_to_cid(child);
let cs = costs.get(child_cid).unwrap();
desc += cs.costs.len();
children_cost += cs.total;
// No children -> easy.
if node.children.is_empty() {
return CostSet {
costs: std::collections::HashMap::default(),
total: node.cost,
choice: node_id.clone(),
};
}

let mut cost_set = CostSet {
costs: std::collections::HashMap::with_capacity(desc),
total: Cost::default(),
choice: node_id.clone(),
};

if node.children.len() == 1 && (node.cost + children_cost > best_cost) {
// shortcut. Not going to be better - don't bother filling the costs.
cost_set.total = node.cost + children_cost;
return cost_set;
// Get unique classes of children.
let mut childrens_classes = node
.children
.iter()
.map(|c| egraph.nid_to_cid(&c).clone())
.collect::<Vec<ClassId>>();
childrens_classes.sort();
childrens_classes.dedup();

let first_cost = costs.get(&childrens_classes[0]).unwrap();

if childrens_classes.len() == 1 && (node.cost + first_cost.total > best_cost) {
// Shortcut. Can't be cheaper so return junk.
return CostSet {
costs: std::collections::HashMap::default(),
total: INFINITY,
choice: node_id.clone(),
};
}

if node.children.len() == 1 {
// just clone it.
let child = &node.children[0];
let child_cid = egraph.nid_to_cid(&child);
cost_set.costs = costs.get(child_cid).unwrap().costs.clone();
} else {
for child in &node.children {
let child_cid = egraph.nid_to_cid(child);
let costs = &costs.get(child_cid).unwrap().costs;
for (key, value) in costs.iter() {
cost_set.costs.insert(key.clone(), value.clone());
}
// Clone the biggest set and insert the others into it.
let id_of_biggest = childrens_classes
.iter()
.max_by_key(|s| costs.get(s).unwrap().costs.len())
.unwrap();
let mut result = costs.get(&id_of_biggest).unwrap().costs.clone();
for child_cid in &childrens_classes {
if child_cid == id_of_biggest {
continue;
}

//cost_set.costs.extend(costs);
let next_cost = &costs.get(child_cid).unwrap().costs;
for (key, value) in next_cost.iter() {
result.insert(key.clone(), value.clone());
}
}

let contains = cost_set.costs.contains_key(&cid.clone());
cost_set.costs.insert(cid.clone(), node.cost); // this node.
let cid = egraph.nid_to_cid(&node_id);
let contains = result.contains_key(&cid);
result.insert(cid.clone(), node.cost);

if contains {
cost_set.total = INFINITY;
let result_cost = if contains {
INFINITY
} else {
if cost_set.costs.len() == desc + 1 {
// No extra duplicates are found, so the cost is the current
// nodes cost + the children's cost.
cost_set.total = children_cost + node.cost;
} else {
cost_set.total = cost_set.costs.values().sum();
}
result.values().sum()
};

cost_set
return CostSet {
costs: result,
total: result_cost,
choice: node_id.clone(),
};
}
}

Expand Down

0 comments on commit 5180a1f

Please sign in to comment.