From 84d67bf2a5d361a4529eff932a059cbb58e3017f Mon Sep 17 00:00:00 2001 From: Trevor Hansen Date: Fri, 12 Jan 2024 11:02:10 +1100 Subject: [PATCH] ILP-based extractor that produces an optimal dag-cost extraction (#30) * better ilp extractor * Separate extractor with timeout and without timeout * fix build. enable ilp-cbc by default * fix to use ilp-cbc with timeout * Move the timeout to the caller * Removing cbc because I don't have permission to install cbc during github workflow * re-enable global-greedy-dag * add description of how cycles are blocked * Following the suggestion of @mwillsey to use floats rather than integers for the levels * fix layout --- Cargo.lock | 7 + src/extract/ilp_cbc.rs | 381 +++++++++++++++++------------------------ src/main.rs | 5 + 3 files changed, 173 insertions(+), 220 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 46f2e35..707fbcd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -93,6 +93,7 @@ dependencies = [ "ordered-float", "pico-args", "rpds", + "rustc-hash", ] [[package]] @@ -237,6 +238,12 @@ dependencies = [ "archery", ] +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "ryu" version = "1.0.14" diff --git a/src/extract/ilp_cbc.rs b/src/extract/ilp_cbc.rs index c48b04b..82522bd 100644 --- a/src/extract/ilp_cbc.rs +++ b/src/extract/ilp_cbc.rs @@ -1,267 +1,208 @@ -use core::panic; +/* An ILP extractor that returns the optimal DAG-extraction. + +This extractor is simple so that it's easy to see that it's correct. + +If the timeout is reached, it will return the result of the faster-greedy-dag extractor. 
+*/ use super::*; use coin_cbc::{Col, Model, Sense}; use indexmap::IndexSet; -const INITIALISE_WITH_BOTTOM_UP: bool = false; - struct ClassVars { active: Col, nodes: Vec, } -pub struct CbcExtractor; +pub struct CbcExtractorWithTimeout; -impl Extractor for CbcExtractor { +impl Extractor for CbcExtractorWithTimeout { fn extract(&self, egraph: &EGraph, roots: &[ClassId]) -> ExtractionResult { - let mut model = Model::default(); - - let true_literal = model.add_binary(); - model.set_col_lower(true_literal, 1.0); - - let vars: IndexMap = egraph - .classes() - .values() - .map(|class| { - let cvars = ClassVars { - active: if roots.contains(&class.id) { - // Roots must be active. - true_literal - } else { - model.add_binary() - }, - nodes: class.nodes.iter().map(|_| model.add_binary()).collect(), - }; - (class.id.clone(), cvars) - }) - .collect(); - - for (class_id, class) in &vars { - // class active == some node active - // sum(for node_active in class) == class_active - let row = model.add_row(); - model.set_row_equal(row, 0.0); - model.set_weight(row, class.active, -1.0); - for &node_active in &class.nodes { - model.set_weight(row, node_active, 1.0); - } + return extract(egraph, roots, TIMEOUT_IN_SECONDS); + } +} - let childrens_classes_var = |nid: NodeId| { - egraph[&nid] - .children - .iter() - .map(|n| egraph[n].eclass.clone()) - .map(|n| vars[&n].active) - .collect::>() - }; +pub struct CbcExtractor; - let mut intersection: IndexSet = - childrens_classes_var(egraph[class_id].nodes[0].clone()); +impl Extractor for CbcExtractor { + fn extract(&self, egraph: &EGraph, roots: &[ClassId]) -> ExtractionResult { + return extract(egraph, roots, std::u32::MAX); + } +} - for node in &egraph[class_id].nodes[1..] 
{ - intersection = intersection - .intersection(&childrens_classes_var(node.clone())) - .cloned() - .collect(); - } +fn extract(egraph: &EGraph, roots: &[ClassId], timeout_seconds: u32) -> ExtractionResult { + let mut model = Model::default(); - // A class being active implies that all in the intersection - // of it's children are too. - for c in &intersection { - let row = model.add_row(); - model.set_row_upper(row, 0.0); - model.set_weight(row, class.active, 1.0); - model.set_weight(row, *c, -1.0); - } + model.set_parameter("seconds", &timeout_seconds.to_string()); - for (node_id, &node_active) in egraph[class_id].nodes.iter().zip(&class.nodes) { - for child_active in childrens_classes_var(node_id.clone()) { - // node active implies child active, encoded as: - // node_active <= child_active - // node_active - child_active <= 0 - if !intersection.contains(&child_active) { - let row = model.add_row(); - model.set_row_upper(row, 0.0); - model.set_weight(row, node_active, 1.0); - model.set_weight(row, child_active, -1.0); - } - } - } + let vars: IndexMap = egraph + .classes() + .values() + .map(|class| { + let cvars = ClassVars { + active: model.add_binary(), + nodes: class.nodes.iter().map(|_| model.add_binary()).collect(), + }; + (class.id.clone(), cvars) + }) + .collect(); + + for (class_id, class) in &vars { + // class active == some node active + // sum(for node_active in class) == class_active + let row = model.add_row(); + model.set_row_equal(row, 0.0); + model.set_weight(row, class.active, -1.0); + for &node_active in &class.nodes { + model.set_weight(row, node_active, 1.0); } - model.set_obj_sense(Sense::Minimize); - for class in egraph.classes().values() { - let min_cost = class - .nodes + let childrens_classes_var = |nid: NodeId| { + egraph[&nid] + .children .iter() - .map(|n_id| egraph[n_id].cost) - .min() - .unwrap_or(Cost::default()) - .into_inner(); - - // Most helpful when the members of the class all have the same cost. 
- // For example if the members' costs are [1,1,1], three terms get - // replaced by one in the objective function. - if min_cost != 0.0 { - model.set_obj_coeff(vars[&class.id].active, min_cost); - } - - for (node_id, &node_active) in class.nodes.iter().zip(&vars[&class.id].nodes) { - let node = &egraph[node_id]; - let node_cost = node.cost.into_inner() - min_cost; - assert!(node_cost >= 0.0); - - if node_cost != 0.0 { - model.set_obj_coeff(node_active, node_cost); - } + .map(|n| egraph[n].eclass.clone()) + .map(|n| vars[&n].active) + .collect::>() + }; + + for (node_id, &node_active) in egraph[class_id].nodes.iter().zip(&class.nodes) { + for child_active in childrens_classes_var(node_id.clone()) { + // node active implies child active, encoded as: + // node_active <= child_active + // node_active - child_active <= 0 + let row = model.add_row(); + model.set_row_upper(row, 0.0); + model.set_weight(row, node_active, 1.0); + model.set_weight(row, child_active, -1.0); } } + } - // set initial solution based on bottom up extractor - if INITIALISE_WITH_BOTTOM_UP { - let initial_result = super::bottom_up::BottomUpExtractor.extract(egraph, roots); - for (class, class_vars) in egraph.classes().values().zip(vars.values()) { - if let Some(node_id) = initial_result.choices.get(&class.id) { - model.set_col_initial_solution(class_vars.active, 1.0); - for col in &class_vars.nodes { - model.set_col_initial_solution(*col, 0.0); - } - let node_idx = class.nodes.iter().position(|n| n == node_id).unwrap(); - model.set_col_initial_solution(class_vars.nodes[node_idx], 1.0); - } else { - model.set_col_initial_solution(class_vars.active, 0.0); - } - } - } + model.set_obj_sense(Sense::Minimize); + for class in egraph.classes().values() { + for (node_id, &node_active) in class.nodes.iter().zip(&vars[&class.id].nodes) { + let node = &egraph[node_id]; + let node_cost = node.cost.into_inner(); + assert!(node_cost >= 0.0); - let mut banned_cycles: IndexSet<(ClassId, usize)> = 
Default::default(); - find_cycles(egraph, |id, i| { - banned_cycles.insert((id, i)); - }); - for (class_id, class_vars) in &vars { - for (i, &node_active) in class_vars.nodes.iter().enumerate() { - if banned_cycles.contains(&(class_id.clone(), i)) { - model.set_col_upper(node_active, 0.0); - model.set_col_lower(node_active, 0.0); - } - } - } - log::info!("@blocked {}", banned_cycles.len()); - - let solution = model.solve(); - log::info!( - "CBC status {:?}, {:?}, obj = {}", - solution.raw().status(), - solution.raw().secondary_status(), - solution.raw().obj_value(), - ); - - let mut result = ExtractionResult::default(); - - for (id, var) in &vars { - let active = solution.col(var.active) > 0.0; - if active { - let node_idx = var - .nodes - .iter() - .position(|&n| solution.col(n) > 0.0) - .unwrap(); - let node_id = egraph[id].nodes[node_idx].clone(); - result.choose(id.clone(), node_id); + if node_cost != 0.0 { + model.set_obj_coeff(node_active, node_cost); } } - - let cycles = result.find_cycles(egraph, roots); - assert!(cycles.is_empty()); - return result; } -} -// from @khaki3 -// fixes bug in egg 0.9.4's version -// https://github.com/egraphs-good/egg/issues/207#issuecomment-1264737441 -fn find_cycles(egraph: &EGraph, mut f: impl FnMut(ClassId, usize)) { - let mut pending: IndexMap> = IndexMap::default(); + for root in roots { + model.set_col_lower(vars[root].active, 1.0); + } - let mut order: IndexMap = IndexMap::default(); + block_cycles(&mut model, &vars, &egraph); - let mut memo: IndexMap<(ClassId, usize), bool> = IndexMap::default(); + let solution = model.solve(); + log::info!( + "CBC status {:?}, {:?}, obj = {}", + solution.raw().status(), + solution.raw().secondary_status(), + solution.raw().obj_value(), + ); - let mut stack: Vec<(ClassId, usize)> = vec![]; + if solution.raw().status() != coin_cbc::raw::Status::Finished { + assert!(timeout_seconds != std::u32::MAX); - let n2c = |nid: &NodeId| egraph.nid_to_cid(nid); + let initial_result = + 
super::faster_greedy_dag::FasterGreedyDagExtractor.extract(egraph, roots); + log::info!("Unfinished CBC solution"); + return initial_result; + } - for class in egraph.classes().values() { - let id = &class.id; - for (i, node_id) in egraph[id].nodes.iter().enumerate() { - let node = &egraph[node_id]; - for child in &node.children { - let child = n2c(child).clone(); - pending - .entry(child) - .or_insert_with(Vec::new) - .push((id.clone(), i)); - } + let mut result = ExtractionResult::default(); - if node.is_leaf() { - stack.push((id.clone(), i)); - } + for (id, var) in &vars { + let active = solution.col(var.active) > 0.0; + if active { + let node_idx = var + .nodes + .iter() + .position(|&n| solution.col(n) > 0.0) + .unwrap(); + let node_id = egraph[id].nodes[node_idx].clone(); + result.choose(id.clone(), node_id); } } - let mut count = 0; - - while let Some((id, i)) = stack.pop() { - if memo.get(&(id.clone(), i)).is_some() { - continue; - } + return result; +} - let node_id = &egraph[&id].nodes[i]; - let node = &egraph[node_id]; - let mut update = false; - - if node.is_leaf() { - update = true; - } else if node.children.iter().all(|x| order.get(n2c(x)).is_some()) { - if let Some(ord) = order.get(&id) { - update = node.children.iter().all(|x| &order[n2c(x)] < ord); - if !update { - memo.insert((id, i), false); - continue; - } - } else { - update = true; - } - } +/* + + To block cycles, we enforce that a topological ordering exists on the extraction. + Each class is mapped to a variable (called its level). Then for each node, + we add a constraint that if a node is active, then the level of the class the node + belongs to must be less than the level of each of the node's children. + + To create a cycle, the levels would need to decrease, so they're blocked.
For example, + given a two class cycle: if class A, has level 'l', and class B has level 'm', then + 'l' must be less than 'm', but because there is also an active node in class B that + has class A as a child, 'm' must be less than 'l', which is a contradiction. +*/ + +fn block_cycles(model: &mut Model, vars: &IndexMap, egraph: &EGraph) { + let mut levels: IndexMap = Default::default(); + for c in vars.keys() { + let var = model.add_col(); + levels.insert(c.clone(), var); + //model.set_col_lower(var, 0.0); + // It solves the benchmarks about 5% faster without this + //model.set_col_upper(var, vars.len() as f64); + } - if update { - if order.get(&id).is_none() { - if egraph[node_id].is_leaf() { - order.insert(id.clone(), 0); - } else { - order.insert(id.clone(), count); - count += 1; - } - } - memo.insert((id.clone(), i), true); - if let Some(mut v) = pending.remove(&id) { - stack.append(&mut v); - stack.sort(); - stack.dedup(); - }; + // If n.variable is true, opposite_col will be false and vice versa. + let mut opposite: IndexMap = Default::default(); + for c in vars.values() { + for n in &c.nodes { + let opposite_col = model.add_binary(); + opposite.insert(*n, opposite_col); + let row = model.add_row(); + model.set_row_equal(row, 1.0); + model.set_weight(row, opposite_col, 1.0); + model.set_weight(row, *n, 1.0); } } - for class in egraph.classes().values() { - let id = &class.id; - for (i, node) in class.nodes.iter().enumerate() { - if let Some(true) = memo.get(&(id.clone(), i)) { + for (class_id, c) in vars { + for i in 0..c.nodes.len() { + let n_id = &egraph[class_id].nodes[i]; + let n = &egraph[n_id]; + let var = c.nodes[i]; + + let children_classes = n + .children + .iter() + .map(|n| egraph[n].eclass.clone()) + .collect::>(); + + if children_classes.contains(class_id) { + // Self loop - disable this node. 
+ // This is clumsier than calling set_col_lower(var,0.0), + // but means it'll be infeasible (rather than producing an + // incorrect solution) if var corresponds to a root node. + let row = model.add_row(); + model.set_weight(row, var, 1.0); + model.set_row_equal(row, 0.0); continue; } - assert!(!egraph[node].is_leaf()); - f(id.clone(), i); + + for cc in children_classes { + assert!(*levels.get(class_id).unwrap() != *levels.get(&cc).unwrap()); + + let row = model.add_row(); + model.set_row_lower(row, 1.0); + model.set_weight(row, *levels.get(class_id).unwrap(), -1.0); + model.set_weight(row, *levels.get(&cc).unwrap(), 1.0); + + // If n.variable is 0, then disable the constraint. + model.set_weight(row, *opposite.get(&var).unwrap(), (vars.len() + 1) as f64); + } } } - assert!(pending.is_empty()); } diff --git a/src/main.rs b/src/main.rs index bddc552..fab1624 100644 --- a/src/main.rs +++ b/src/main.rs @@ -36,6 +36,11 @@ fn main() { "global-greedy-dag", extract::global_greedy_dag::GlobalGreedyDagExtractor.boxed(), ), + #[cfg(feature = "ilp-cbc")] + ( + "ilp-cbc-timeout", + extract::ilp_cbc::CbcExtractorWithTimeout::<10>.boxed(), + ), ] .into_iter() .collect();