From 5fc973950f3b53858de083c0f088a84b1e86dc67 Mon Sep 17 00:00:00 2001 From: Trevor Hansen Date: Fri, 15 Dec 2023 15:33:57 +1100 Subject: [PATCH 1/2] Cleanup and move to fxHash --- src/extract/bottom_up.rs | 5 +- src/extract/faster_bottom_up.rs | 4 +- .../{greedy_dag_1.rs => faster_greedy_dag.rs} | 55 +++++++------------ src/extract/greedy_dag.rs | 22 +++----- src/extract/mod.rs | 3 +- src/main.rs | 4 +- 6 files changed, 35 insertions(+), 58 deletions(-) rename src/extract/{greedy_dag_1.rs => faster_greedy_dag.rs} (79%) diff --git a/src/extract/bottom_up.rs b/src/extract/bottom_up.rs index 22a2f23..e2f71db 100644 --- a/src/extract/bottom_up.rs +++ b/src/extract/bottom_up.rs @@ -4,7 +4,10 @@ pub struct BottomUpExtractor; impl Extractor for BottomUpExtractor { fn extract(&self, egraph: &EGraph, _roots: &[ClassId]) -> ExtractionResult { let mut result = ExtractionResult::default(); - let mut costs = IndexMap::::default(); + let mut costs = FxHashMap::::with_capacity_and_hasher( + egraph.classes().len(), + Default::default(), + ); let mut did_something = false; loop { diff --git a/src/extract/faster_bottom_up.rs b/src/extract/faster_bottom_up.rs index b10ff65..de3a885 100644 --- a/src/extract/faster_bottom_up.rs +++ b/src/extract/faster_bottom_up.rs @@ -14,9 +14,9 @@ use super::*; /// This algorithm instead only visits the nodes whose current cost estimate may change: /// it does this by tracking parent-child relationships and storing relevant nodes /// in a work list (UniqueQueue). 
-pub struct BottomUpExtractor; +pub struct FasterBottomUpExtractor; -impl Extractor for BottomUpExtractor { +impl Extractor for FasterBottomUpExtractor { fn extract(&self, egraph: &EGraph, _roots: &[ClassId]) -> ExtractionResult { let mut parents = IndexMap::>::with_capacity(egraph.classes().len()); let n2c = |nid: &NodeId| egraph.nid_to_cid(nid); diff --git a/src/extract/greedy_dag_1.rs b/src/extract/faster_greedy_dag.rs similarity index 79% rename from src/extract/greedy_dag_1.rs rename to src/extract/faster_greedy_dag.rs index ea4c76b..b672404 100644 --- a/src/extract/greedy_dag_1.rs +++ b/src/extract/faster_greedy_dag.rs @@ -3,9 +3,10 @@ // included in the cost. use super::*; +use rustc_hash::{FxHashMap, FxHashSet}; struct CostSet { - costs: std::collections::HashMap, + costs: FxHashMap, total: Cost, choice: NodeId, } @@ -16,7 +17,7 @@ impl FasterGreedyDagExtractor { fn calculate_cost_set( egraph: &EGraph, node_id: NodeId, - costs: &HashMap, + costs: &FxHashMap, best_cost: Cost, ) -> CostSet { let node = &egraph[&node_id]; @@ -24,7 +25,7 @@ impl FasterGreedyDagExtractor { // No children -> easy. if node.children.is_empty() { return CostSet { - costs: std::collections::HashMap::default(), + costs: Default::default(), total: node.cost, choice: node_id.clone(), }; @@ -44,7 +45,7 @@ impl FasterGreedyDagExtractor { if childrens_classes.len() == 1 && (node.cost + first_cost.total > best_cost) { // Shortcut. Can't be cheaper so return junk. 
return CostSet { - costs: std::collections::HashMap::default(), + costs: Default::default(), total: INFINITY, choice: node_id.clone(), }; @@ -85,46 +86,35 @@ impl FasterGreedyDagExtractor { } } -impl FasterGreedyDagExtractor { - fn check(egraph: &EGraph, node_id: NodeId, costs: &HashMap) { - let cid = egraph.nid_to_cid(&node_id); - let previous = costs.get(cid).unwrap().total; - let cs = Self::calculate_cost_set(egraph, node_id, costs, INFINITY); - println!("{} {}", cs.total, previous); - assert!(cs.total >= previous); - } -} - impl Extractor for FasterGreedyDagExtractor { fn extract(&self, egraph: &EGraph, _roots: &[ClassId]) -> ExtractionResult { - // 1. build map from class to parent nodes - let mut parents = IndexMap::>::default(); + let mut parents = IndexMap::>::with_capacity(egraph.classes().len()); let n2c = |nid: &NodeId| egraph.nid_to_cid(nid); + let mut analysis_pending = UniqueQueue::default(); for class in egraph.classes().values() { parents.insert(class.id.clone(), Vec::new()); } + for class in egraph.classes().values() { for node in &class.nodes { for c in &egraph[node].children { + // compute parents of this enode parents[n2c(c)].push(node.clone()); } - } - } - // 2. start analysis from leaves - let mut analysis_pending = UniqueQueue::default(); - - for class in egraph.classes().values() { - for node in &class.nodes { + // start the analysis from leaves if egraph[node].is_leaf() { analysis_pending.insert(node.clone()); } } } - // 3. 
analyse from leaves towards parents until fixpoint - let mut costs = HashMap::::default(); + let mut result = ExtractionResult::default(); + let mut costs = FxHashMap::::with_capacity_and_hasher( + egraph.classes().len(), + Default::default(), + ); while let Some(node_id) = analysis_pending.pop() { let class_id = n2c(&node_id); @@ -144,15 +134,6 @@ impl Extractor for FasterGreedyDagExtractor { } } - /* - for class in egraph.classes().values() { - for node in &class.nodes { - Self::check(&egraph, node.clone(), &costs); - } - } - */ - - let mut result = ExtractionResult::default(); for (cid, cost_set) in costs { result.choose(cid, cost_set.choice); } @@ -164,6 +145,8 @@ impl Extractor for FasterGreedyDagExtractor { /** A data structure to maintain a queue of unique elements. Notably, insert/pop operations have O(1) expected amortized runtime complexity. + +Thanks @Bastacyclop for the implementation! */ #[derive(Clone)] #[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))] @@ -171,7 +154,7 @@ pub(crate) struct UniqueQueue where T: Eq + std::hash::Hash + Clone, { - set: std::collections::HashSet, // hashbrown:: + set: FxHashSet, // hashbrown:: queue: std::collections::VecDeque, } @@ -181,7 +164,7 @@ where { fn default() -> Self { UniqueQueue { - set: std::collections::HashSet::default(), + set: Default::default(), queue: std::collections::VecDeque::new(), } } diff --git a/src/extract/greedy_dag.rs b/src/extract/greedy_dag.rs index c215916..5615181 100644 --- a/src/extract/greedy_dag.rs +++ b/src/extract/greedy_dag.rs @@ -10,10 +10,12 @@ struct CostSet { pub struct GreedyDagExtractor; impl Extractor for GreedyDagExtractor { fn extract(&self, egraph: &EGraph, _roots: &[ClassId]) -> ExtractionResult { - let mut costs = IndexMap::::default(); - let mut keep_going = true; + let mut costs = FxHashMap::::with_capacity_and_hasher( + egraph.classes().len(), + Default::default(), + ); - let mut nodes = egraph.nodes.clone(); + let mut keep_going = true; let mut i = 0; 
while keep_going { @@ -21,12 +23,10 @@ impl Extractor for GreedyDagExtractor { println!("iteration {}", i); keep_going = false; - let mut to_remove = vec![]; - - 'node_loop: for (node_id, node) in &nodes { + 'node_loop: for (node_id, node) in &egraph.nodes { let cid = egraph.nid_to_cid(node_id); let mut cost_set = CostSet { - costs: FxHashMap::default(), + costs: Default::default(), total: Cost::default(), choice: node_id.clone(), }; @@ -60,14 +60,6 @@ impl Extractor for GreedyDagExtractor { costs.insert(cid.clone(), cost_set); keep_going = true; } - to_remove.push(node_id.clone()); - } - - // removing nodes you've "done" can speed it up a lot but makes the results much worse - if false { - for node_id in to_remove { - nodes.remove(&node_id); - } } } diff --git a/src/extract/mod.rs b/src/extract/mod.rs index 54a3753..413827e 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -6,10 +6,9 @@ pub use crate::*; pub mod bottom_up; pub mod faster_bottom_up; +pub mod faster_greedy_dag; pub mod global_greedy_dag; pub mod greedy_dag; -pub mod greedy_dag_1; - #[cfg(feature = "ilp-cbc")] pub mod ilp_cbc; diff --git a/src/main.rs b/src/main.rs index 19463cb..9c90a09 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,7 +22,7 @@ fn main() { ("bottom-up", extract::bottom_up::BottomUpExtractor.boxed()), ( "faster-bottom-up", - extract::faster_bottom_up::BottomUpExtractor.boxed(), + extract::faster_bottom_up::FasterBottomUpExtractor.boxed(), ), ( "greedy-dag", @@ -30,7 +30,7 @@ fn main() { ), ( "faster-greedy-dag", - extract::greedy_dag_1::FasterGreedyDagExtractor.boxed(), + extract::faster_greedy_dag::FasterGreedyDagExtractor.boxed(), ), ( "global-greedy-dag", From 79f9eee3de483544ce00b87e961cec1d5656d7d2 Mon Sep 17 00:00:00 2001 From: Trevor Hansen Date: Fri, 15 Dec 2023 16:05:36 +1100 Subject: [PATCH 2/2] slight speedup when using the hashmap. 
--- src/extract/faster_greedy_dag.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/extract/faster_greedy_dag.rs b/src/extract/faster_greedy_dag.rs index b672404..1ec72bc 100644 --- a/src/extract/faster_greedy_dag.rs +++ b/src/extract/faster_greedy_dag.rs @@ -6,7 +6,8 @@ use super::*; use rustc_hash::{FxHashMap, FxHashSet}; struct CostSet { - costs: FxHashMap<ClassId, Cost>, + // It's slightly faster if this is a HashMap rather than an FxHashMap. + costs: HashMap<ClassId, Cost>, total: Cost, choice: NodeId, }