diff --git a/src/extract/ilp_cbc.rs b/src/extract/ilp_cbc.rs
index ed40e09..c48b04b 100644
--- a/src/extract/ilp_cbc.rs
+++ b/src/extract/ilp_cbc.rs
@@ -4,9 +4,10 @@ use super::*;
use coin_cbc::{Col, Model, Sense};
use indexmap::IndexSet;
+const INITIALISE_WITH_BOTTOM_UP: bool = false;
+
struct ClassVars {
active: Col,
- order: Col,
nodes: Vec
,
}
@@ -14,27 +15,29 @@ pub struct CbcExtractor;
impl Extractor for CbcExtractor {
fn extract(&self, egraph: &EGraph, roots: &[ClassId]) -> ExtractionResult {
- let max_order = egraph.nodes.len() as f64 * 10.0;
-
let mut model = Model::default();
- // model.set_parameter("seconds", "30");
- // model.set_parameter("allowableGap", "100000000");
+
+ let true_literal = model.add_binary();
+ model.set_col_lower(true_literal, 1.0);
let vars: IndexMap = egraph
.classes()
.values()
.map(|class| {
let cvars = ClassVars {
- active: model.add_binary(),
- order: model.add_col(),
+ active: if roots.contains(&class.id) {
+ // Roots must be active.
+ true_literal
+ } else {
+ model.add_binary()
+ },
nodes: class.nodes.iter().map(|_| model.add_binary()).collect(),
};
- model.set_col_upper(cvars.order, max_order);
(class.id.clone(), cvars)
})
.collect();
- for (id, class) in &vars {
+ for (class_id, class) in &vars {
// class active == some node active
// sum(for node_active in class) == class_active
let row = model.add_row();
@@ -44,121 +47,134 @@ impl Extractor for CbcExtractor {
model.set_weight(row, node_active, 1.0);
}
- for (node_id, &node_active) in egraph[id].nodes.iter().zip(&class.nodes) {
- let node = &egraph[node_id];
- for child in &node.children {
- let eclass_id = &egraph[child].eclass;
- let child_active = vars[eclass_id].active;
+ let childrens_classes_var = |nid: NodeId| {
+ egraph[&nid]
+ .children
+ .iter()
+ .map(|n| egraph[n].eclass.clone())
+ .map(|n| vars[&n].active)
+ .collect::>()
+ };
+
+ let mut intersection: IndexSet =
+ childrens_classes_var(egraph[class_id].nodes[0].clone());
+
+ for node in &egraph[class_id].nodes[1..] {
+ intersection = intersection
+ .intersection(&childrens_classes_var(node.clone()))
+ .cloned()
+ .collect();
+ }
+
+ // A class being active implies that all in the intersection
+ // of it's children are too.
+ for c in &intersection {
+ let row = model.add_row();
+ model.set_row_upper(row, 0.0);
+ model.set_weight(row, class.active, 1.0);
+ model.set_weight(row, *c, -1.0);
+ }
+
+ for (node_id, &node_active) in egraph[class_id].nodes.iter().zip(&class.nodes) {
+ for child_active in childrens_classes_var(node_id.clone()) {
// node active implies child active, encoded as:
// node_active <= child_active
// node_active - child_active <= 0
- let row = model.add_row();
- model.set_row_upper(row, 0.0);
- model.set_weight(row, node_active, 1.0);
- model.set_weight(row, child_active, -1.0);
+ if !intersection.contains(&child_active) {
+ let row = model.add_row();
+ model.set_row_upper(row, 0.0);
+ model.set_weight(row, node_active, 1.0);
+ model.set_weight(row, child_active, -1.0);
+ }
}
}
}
model.set_obj_sense(Sense::Minimize);
for class in egraph.classes().values() {
- for (node_id, &node_active) in class.nodes.iter().zip(&vars[&class.id].nodes) {
- let node = &egraph[node_id];
- model.set_obj_coeff(node_active, node.cost.into_inner());
+ let min_cost = class
+ .nodes
+ .iter()
+ .map(|n_id| egraph[n_id].cost)
+ .min()
+ .unwrap_or(Cost::default())
+ .into_inner();
+
+ // Most helpful when the members of the class all have the same cost.
+ // For example if the members' costs are [1,1,1], three terms get
+ // replaced by one in the objective function.
+ if min_cost != 0.0 {
+ model.set_obj_coeff(vars[&class.id].active, min_cost);
}
- }
-
- // model is now ready to go, time to solve
- dbg!(max_order);
-
- for class in vars.values() {
- model.set_binary(class.active);
- }
- for root in roots {
- model.set_col_lower(vars[root].active, 1.0);
- }
+ for (node_id, &node_active) in class.nodes.iter().zip(&vars[&class.id].nodes) {
+ let node = &egraph[node_id];
+ let node_cost = node.cost.into_inner() - min_cost;
+ assert!(node_cost >= 0.0);
- // set initial solution based on bottom up extractor
- let initial_result = super::bottom_up::BottomUpExtractor.extract(egraph, roots);
- for (class, class_vars) in egraph.classes().values().zip(vars.values()) {
- if let Some(node_id) = initial_result.choices.get(&class.id) {
- model.set_col_initial_solution(class_vars.active, 1.0);
- for col in &class_vars.nodes {
- model.set_col_initial_solution(*col, 0.0);
+ if node_cost != 0.0 {
+ model.set_obj_coeff(node_active, node_cost);
}
- let node_idx = class.nodes.iter().position(|n| n == node_id).unwrap();
- model.set_col_initial_solution(class_vars.nodes[node_idx], 1.0);
- } else {
- model.set_col_initial_solution(class_vars.active, 0.0);
}
}
- let mut banned_cycles: IndexSet<(ClassId, usize)> = Default::default();
- // find_cycles(egraph, |id, i| {
- // banned_cycles.insert((id, i));
- // });
-
- for iteration in 0.. {
- if iteration == 0 {
- find_cycles(egraph, |id, i| {
- banned_cycles.insert((id, i));
- });
- } else if iteration >= 2 {
- panic!("Too many iterations");
- }
-
- for (id, class) in &vars {
- for (i, (_node, &node_active)) in
- egraph[id].nodes.iter().zip(&class.nodes).enumerate()
- {
- if banned_cycles.contains(&(id.clone(), i)) {
- model.set_col_upper(node_active, 0.0);
- model.set_col_lower(node_active, 0.0);
+ // set initial solution based on bottom up extractor
+ if INITIALISE_WITH_BOTTOM_UP {
+ let initial_result = super::bottom_up::BottomUpExtractor.extract(egraph, roots);
+ for (class, class_vars) in egraph.classes().values().zip(vars.values()) {
+ if let Some(node_id) = initial_result.choices.get(&class.id) {
+ model.set_col_initial_solution(class_vars.active, 1.0);
+ for col in &class_vars.nodes {
+ model.set_col_initial_solution(*col, 0.0);
}
+ let node_idx = class.nodes.iter().position(|n| n == node_id).unwrap();
+ model.set_col_initial_solution(class_vars.nodes[node_idx], 1.0);
+ } else {
+ model.set_col_initial_solution(class_vars.active, 0.0);
}
}
+ }
- let solution = model.solve();
- log::info!(
- "CBC status {:?}, {:?}, obj = {}",
- solution.raw().status(),
- solution.raw().secondary_status(),
- solution.raw().obj_value(),
- );
-
- let mut result = ExtractionResult::default();
-
- for (id, var) in &vars {
- let active = solution.col(var.active) > 0.0;
- if active {
- let node_idx = var
- .nodes
- .iter()
- .position(|&n| solution.col(n) > 0.0)
- .unwrap();
- let node_id = egraph[id].nodes[node_idx].clone();
- result.choose(id.clone(), node_id);
+ let mut banned_cycles: IndexSet<(ClassId, usize)> = Default::default();
+ find_cycles(egraph, |id, i| {
+ banned_cycles.insert((id, i));
+ });
+ for (class_id, class_vars) in &vars {
+ for (i, &node_active) in class_vars.nodes.iter().enumerate() {
+ if banned_cycles.contains(&(class_id.clone(), i)) {
+ model.set_col_upper(node_active, 0.0);
+ model.set_col_lower(node_active, 0.0);
}
}
-
- let cycles = result.find_cycles(egraph, roots);
- if cycles.is_empty() {
- return result;
- } else {
- log::info!("Found {} cycles", cycles.len());
- // for id in cycles {
- // let class = &vars[&id];
- // let node_idx = class
- // .nodes
- // .iter()
- // .position(|&n| solution.col(n) > 0.0)
- // .unwrap();
- // banned_cycles.insert((id, node_idx));
- // }
+ }
+ log::info!("@blocked {}", banned_cycles.len());
+
+ let solution = model.solve();
+ log::info!(
+ "CBC status {:?}, {:?}, obj = {}",
+ solution.raw().status(),
+ solution.raw().secondary_status(),
+ solution.raw().obj_value(),
+ );
+
+ let mut result = ExtractionResult::default();
+
+ for (id, var) in &vars {
+ let active = solution.col(var.active) > 0.0;
+ if active {
+ let node_idx = var
+ .nodes
+ .iter()
+ .position(|&n| solution.col(n) > 0.0)
+ .unwrap();
+ let node_id = egraph[id].nodes[node_idx].clone();
+ result.choose(id.clone(), node_id);
}
}
- unreachable!()
+
+ let cycles = result.find_cycles(egraph, roots);
+ assert!(cycles.is_empty());
+ return result;
}
}
@@ -221,8 +237,12 @@ fn find_cycles(egraph: &EGraph, mut f: impl FnMut(ClassId, usize)) {
if update {
if order.get(&id).is_none() {
- order.insert(id.clone(), count);
- count += 1;
+ if egraph[node_id].is_leaf() {
+ order.insert(id.clone(), 0);
+ } else {
+ order.insert(id.clone(), count);
+ count += 1;
+ }
}
memo.insert((id.clone(), i), true);
if let Some(mut v) = pending.remove(&id) {
@@ -235,11 +255,13 @@ fn find_cycles(egraph: &EGraph, mut f: impl FnMut(ClassId, usize)) {
for class in egraph.classes().values() {
let id = &class.id;
- for (i, _node) in class.nodes.iter().enumerate() {
+ for (i, node) in class.nodes.iter().enumerate() {
if let Some(true) = memo.get(&(id.clone(), i)) {
continue;
}
+ assert!(!egraph[node].is_leaf());
f(id.clone(), i);
}
}
+ assert!(pending.is_empty());
}