Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small cleanup. Use fxHash instead #23

Merged
merged 2 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/extract/bottom_up.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ pub struct BottomUpExtractor;
impl Extractor for BottomUpExtractor {
fn extract(&self, egraph: &EGraph, _roots: &[ClassId]) -> ExtractionResult {
let mut result = ExtractionResult::default();
let mut costs = IndexMap::<ClassId, Cost>::default();
let mut costs = FxHashMap::<ClassId, Cost>::with_capacity_and_hasher(
egraph.classes().len(),
Default::default(),
);
let mut did_something = false;

loop {
Expand Down
4 changes: 2 additions & 2 deletions src/extract/faster_bottom_up.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ use super::*;
/// This algorithm instead only visits the nodes whose current cost estimate may change:
/// it does this by tracking parent-child relationships and storing relevant nodes
/// in a work list (UniqueQueue).
pub struct BottomUpExtractor;
pub struct FasterBottomUpExtractor;

impl Extractor for BottomUpExtractor {
impl Extractor for FasterBottomUpExtractor {
fn extract(&self, egraph: &EGraph, _roots: &[ClassId]) -> ExtractionResult {
let mut parents = IndexMap::<ClassId, Vec<NodeId>>::with_capacity(egraph.classes().len());
let n2c = |nid: &NodeId| egraph.nid_to_cid(nid);
Expand Down
56 changes: 20 additions & 36 deletions src/extract/greedy_dag_1.rs → src/extract/faster_greedy_dag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
// included in the cost.

use super::*;
use rustc_hash::{FxHashMap, FxHashSet};

struct CostSet {
costs: std::collections::HashMap<ClassId, Cost>,
// It's slightly faster if this is an HashMap rather than an fxHashMap.
costs: HashMap<ClassId, Cost>,
total: Cost,
choice: NodeId,
}
Expand All @@ -16,15 +18,15 @@ impl FasterGreedyDagExtractor {
fn calculate_cost_set(
egraph: &EGraph,
node_id: NodeId,
costs: &HashMap<ClassId, CostSet>,
costs: &FxHashMap<ClassId, CostSet>,
best_cost: Cost,
) -> CostSet {
let node = &egraph[&node_id];

// No children -> easy.
if node.children.is_empty() {
return CostSet {
costs: std::collections::HashMap::default(),
costs: Default::default(),
total: node.cost,
choice: node_id.clone(),
};
Expand All @@ -44,7 +46,7 @@ impl FasterGreedyDagExtractor {
if childrens_classes.len() == 1 && (node.cost + first_cost.total > best_cost) {
// Shortcut. Can't be cheaper so return junk.
return CostSet {
costs: std::collections::HashMap::default(),
costs: Default::default(),
total: INFINITY,
choice: node_id.clone(),
};
Expand Down Expand Up @@ -85,46 +87,35 @@ impl FasterGreedyDagExtractor {
}
}

impl FasterGreedyDagExtractor {
fn check(egraph: &EGraph, node_id: NodeId, costs: &HashMap<ClassId, CostSet>) {
let cid = egraph.nid_to_cid(&node_id);
let previous = costs.get(cid).unwrap().total;
let cs = Self::calculate_cost_set(egraph, node_id, costs, INFINITY);
println!("{} {}", cs.total, previous);
assert!(cs.total >= previous);
}
}

impl Extractor for FasterGreedyDagExtractor {
fn extract(&self, egraph: &EGraph, _roots: &[ClassId]) -> ExtractionResult {
// 1. build map from class to parent nodes
let mut parents = IndexMap::<ClassId, Vec<NodeId>>::default();
let mut parents = IndexMap::<ClassId, Vec<NodeId>>::with_capacity(egraph.classes().len());
let n2c = |nid: &NodeId| egraph.nid_to_cid(nid);
let mut analysis_pending = UniqueQueue::default();

for class in egraph.classes().values() {
parents.insert(class.id.clone(), Vec::new());
}

for class in egraph.classes().values() {
for node in &class.nodes {
for c in &egraph[node].children {
// compute parents of this enode
parents[n2c(c)].push(node.clone());
}
}
}

// 2. start analysis from leaves
let mut analysis_pending = UniqueQueue::default();

for class in egraph.classes().values() {
for node in &class.nodes {
// start the analysis from leaves
if egraph[node].is_leaf() {
analysis_pending.insert(node.clone());
}
}
}

// 3. analyse from leaves towards parents until fixpoint
let mut costs = HashMap::<ClassId, CostSet>::default();
let mut result = ExtractionResult::default();
let mut costs = FxHashMap::<ClassId, CostSet>::with_capacity_and_hasher(
egraph.classes().len(),
Default::default(),
);

while let Some(node_id) = analysis_pending.pop() {
let class_id = n2c(&node_id);
Expand All @@ -144,15 +135,6 @@ impl Extractor for FasterGreedyDagExtractor {
}
}

/*
for class in egraph.classes().values() {
for node in &class.nodes {
Self::check(&egraph, node.clone(), &costs);
}
}
*/

let mut result = ExtractionResult::default();
for (cid, cost_set) in costs {
result.choose(cid, cost_set.choice);
}
Expand All @@ -164,14 +146,16 @@ impl Extractor for FasterGreedyDagExtractor {
/** A data structure to maintain a queue of unique elements.

Notably, insert/pop operations have O(1) expected amortized runtime complexity.

Thanks @Bastacyclop for the implementation!
*/
#[derive(Clone)]
#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
pub(crate) struct UniqueQueue<T>
where
T: Eq + std::hash::Hash + Clone,
{
set: std::collections::HashSet<T>, // hashbrown::
set: FxHashSet<T>, // hashbrown::
queue: std::collections::VecDeque<T>,
}

Expand All @@ -181,7 +165,7 @@ where
{
fn default() -> Self {
UniqueQueue {
set: std::collections::HashSet::default(),
set: Default::default(),
queue: std::collections::VecDeque::new(),
}
}
Expand Down
22 changes: 7 additions & 15 deletions src/extract/greedy_dag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,23 @@ struct CostSet {
pub struct GreedyDagExtractor;
impl Extractor for GreedyDagExtractor {
fn extract(&self, egraph: &EGraph, _roots: &[ClassId]) -> ExtractionResult {
let mut costs = IndexMap::<ClassId, CostSet>::default();
let mut keep_going = true;
let mut costs = FxHashMap::<ClassId, CostSet>::with_capacity_and_hasher(
egraph.classes().len(),
Default::default(),
);

let mut nodes = egraph.nodes.clone();
let mut keep_going = true;

let mut i = 0;
while keep_going {
i += 1;
println!("iteration {}", i);
keep_going = false;

let mut to_remove = vec![];

'node_loop: for (node_id, node) in &nodes {
'node_loop: for (node_id, node) in &egraph.nodes {
let cid = egraph.nid_to_cid(node_id);
let mut cost_set = CostSet {
costs: FxHashMap::default(),
costs: Default::default(),
total: Cost::default(),
choice: node_id.clone(),
};
Expand Down Expand Up @@ -60,14 +60,6 @@ impl Extractor for GreedyDagExtractor {
costs.insert(cid.clone(), cost_set);
keep_going = true;
}
to_remove.push(node_id.clone());
}

// removing nodes you've "done" can speed it up a lot but makes the results much worse
if false {
for node_id in to_remove {
nodes.remove(&node_id);
}
}
}

Expand Down
3 changes: 1 addition & 2 deletions src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ pub use crate::*;

pub mod bottom_up;
pub mod faster_bottom_up;
pub mod faster_greedy_dag;
pub mod global_greedy_dag;
pub mod greedy_dag;
pub mod greedy_dag_1;

#[cfg(feature = "ilp-cbc")]
pub mod ilp_cbc;

Expand Down
4 changes: 2 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ fn main() {
("bottom-up", extract::bottom_up::BottomUpExtractor.boxed()),
(
"faster-bottom-up",
extract::faster_bottom_up::BottomUpExtractor.boxed(),
extract::faster_bottom_up::FasterBottomUpExtractor.boxed(),
),
(
"greedy-dag",
extract::greedy_dag::GreedyDagExtractor.boxed(),
),
(
"faster-greedy-dag",
extract::greedy_dag_1::FasterGreedyDagExtractor.boxed(),
extract::faster_greedy_dag::FasterGreedyDagExtractor.boxed(),
),
(
"global-greedy-dag",
Expand Down