Skip to content

Commit

Permalink
Clean up as well as speed up the ILP extractor. (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
TrevorHansen authored Nov 2, 2023
1 parent f168a55 commit 5d433b0
Showing 1 changed file with 125 additions and 103 deletions.
228 changes: 125 additions & 103 deletions src/extract/ilp_cbc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,40 @@ use super::*;
use coin_cbc::{Col, Model, Sense};
use indexmap::IndexSet;

const INITIALISE_WITH_BOTTOM_UP: bool = false;

struct ClassVars {
active: Col,
order: Col,
nodes: Vec<Col>,
}

pub struct CbcExtractor;

impl Extractor for CbcExtractor {
fn extract(&self, egraph: &EGraph, roots: &[ClassId]) -> ExtractionResult {
let max_order = egraph.nodes.len() as f64 * 10.0;

let mut model = Model::default();
// model.set_parameter("seconds", "30");
// model.set_parameter("allowableGap", "100000000");

let true_literal = model.add_binary();
model.set_col_lower(true_literal, 1.0);

let vars: IndexMap<ClassId, ClassVars> = egraph
.classes()
.values()
.map(|class| {
let cvars = ClassVars {
active: model.add_binary(),
order: model.add_col(),
active: if roots.contains(&class.id) {
// Roots must be active.
true_literal
} else {
model.add_binary()
},
nodes: class.nodes.iter().map(|_| model.add_binary()).collect(),
};
model.set_col_upper(cvars.order, max_order);
(class.id.clone(), cvars)
})
.collect();

for (id, class) in &vars {
for (class_id, class) in &vars {
// class active == some node active
// sum(for node_active in class) == class_active
let row = model.add_row();
Expand All @@ -44,121 +47,134 @@ impl Extractor for CbcExtractor {
model.set_weight(row, node_active, 1.0);
}

for (node_id, &node_active) in egraph[id].nodes.iter().zip(&class.nodes) {
let node = &egraph[node_id];
for child in &node.children {
let eclass_id = &egraph[child].eclass;
let child_active = vars[eclass_id].active;
let childrens_classes_var = |nid: NodeId| {
egraph[&nid]
.children
.iter()
.map(|n| egraph[n].eclass.clone())
.map(|n| vars[&n].active)
.collect::<IndexSet<_>>()
};

let mut intersection: IndexSet<Col> =
childrens_classes_var(egraph[class_id].nodes[0].clone());

for node in &egraph[class_id].nodes[1..] {
intersection = intersection
.intersection(&childrens_classes_var(node.clone()))
.cloned()
.collect();
}

// A class being active implies that all in the intersection
// of it's children are too.
for c in &intersection {
let row = model.add_row();
model.set_row_upper(row, 0.0);
model.set_weight(row, class.active, 1.0);
model.set_weight(row, *c, -1.0);
}

for (node_id, &node_active) in egraph[class_id].nodes.iter().zip(&class.nodes) {
for child_active in childrens_classes_var(node_id.clone()) {
// node active implies child active, encoded as:
// node_active <= child_active
// node_active - child_active <= 0
let row = model.add_row();
model.set_row_upper(row, 0.0);
model.set_weight(row, node_active, 1.0);
model.set_weight(row, child_active, -1.0);
if !intersection.contains(&child_active) {
let row = model.add_row();
model.set_row_upper(row, 0.0);
model.set_weight(row, node_active, 1.0);
model.set_weight(row, child_active, -1.0);
}
}
}
}

model.set_obj_sense(Sense::Minimize);
for class in egraph.classes().values() {
for (node_id, &node_active) in class.nodes.iter().zip(&vars[&class.id].nodes) {
let node = &egraph[node_id];
model.set_obj_coeff(node_active, node.cost.into_inner());
let min_cost = class
.nodes
.iter()
.map(|n_id| egraph[n_id].cost)
.min()
.unwrap_or(Cost::default())
.into_inner();

// Most helpful when the members of the class all have the same cost.
// For example if the members' costs are [1,1,1], three terms get
// replaced by one in the objective function.
if min_cost != 0.0 {
model.set_obj_coeff(vars[&class.id].active, min_cost);
}
}

// model is now ready to go, time to solve
dbg!(max_order);

for class in vars.values() {
model.set_binary(class.active);
}

for root in roots {
model.set_col_lower(vars[root].active, 1.0);
}
for (node_id, &node_active) in class.nodes.iter().zip(&vars[&class.id].nodes) {
let node = &egraph[node_id];
let node_cost = node.cost.into_inner() - min_cost;
assert!(node_cost >= 0.0);

// set initial solution based on bottom up extractor
let initial_result = super::bottom_up::BottomUpExtractor.extract(egraph, roots);
for (class, class_vars) in egraph.classes().values().zip(vars.values()) {
if let Some(node_id) = initial_result.choices.get(&class.id) {
model.set_col_initial_solution(class_vars.active, 1.0);
for col in &class_vars.nodes {
model.set_col_initial_solution(*col, 0.0);
if node_cost != 0.0 {
model.set_obj_coeff(node_active, node_cost);
}
let node_idx = class.nodes.iter().position(|n| n == node_id).unwrap();
model.set_col_initial_solution(class_vars.nodes[node_idx], 1.0);
} else {
model.set_col_initial_solution(class_vars.active, 0.0);
}
}

let mut banned_cycles: IndexSet<(ClassId, usize)> = Default::default();
// find_cycles(egraph, |id, i| {
// banned_cycles.insert((id, i));
// });

for iteration in 0.. {
if iteration == 0 {
find_cycles(egraph, |id, i| {
banned_cycles.insert((id, i));
});
} else if iteration >= 2 {
panic!("Too many iterations");
}

for (id, class) in &vars {
for (i, (_node, &node_active)) in
egraph[id].nodes.iter().zip(&class.nodes).enumerate()
{
if banned_cycles.contains(&(id.clone(), i)) {
model.set_col_upper(node_active, 0.0);
model.set_col_lower(node_active, 0.0);
// set initial solution based on bottom up extractor
if INITIALISE_WITH_BOTTOM_UP {
let initial_result = super::bottom_up::BottomUpExtractor.extract(egraph, roots);
for (class, class_vars) in egraph.classes().values().zip(vars.values()) {
if let Some(node_id) = initial_result.choices.get(&class.id) {
model.set_col_initial_solution(class_vars.active, 1.0);
for col in &class_vars.nodes {
model.set_col_initial_solution(*col, 0.0);
}
let node_idx = class.nodes.iter().position(|n| n == node_id).unwrap();
model.set_col_initial_solution(class_vars.nodes[node_idx], 1.0);
} else {
model.set_col_initial_solution(class_vars.active, 0.0);
}
}
}

let solution = model.solve();
log::info!(
"CBC status {:?}, {:?}, obj = {}",
solution.raw().status(),
solution.raw().secondary_status(),
solution.raw().obj_value(),
);

let mut result = ExtractionResult::default();

for (id, var) in &vars {
let active = solution.col(var.active) > 0.0;
if active {
let node_idx = var
.nodes
.iter()
.position(|&n| solution.col(n) > 0.0)
.unwrap();
let node_id = egraph[id].nodes[node_idx].clone();
result.choose(id.clone(), node_id);
let mut banned_cycles: IndexSet<(ClassId, usize)> = Default::default();
find_cycles(egraph, |id, i| {
banned_cycles.insert((id, i));
});
for (class_id, class_vars) in &vars {
for (i, &node_active) in class_vars.nodes.iter().enumerate() {
if banned_cycles.contains(&(class_id.clone(), i)) {
model.set_col_upper(node_active, 0.0);
model.set_col_lower(node_active, 0.0);
}
}

let cycles = result.find_cycles(egraph, roots);
if cycles.is_empty() {
return result;
} else {
log::info!("Found {} cycles", cycles.len());
// for id in cycles {
// let class = &vars[&id];
// let node_idx = class
// .nodes
// .iter()
// .position(|&n| solution.col(n) > 0.0)
// .unwrap();
// banned_cycles.insert((id, node_idx));
// }
}
log::info!("@blocked {}", banned_cycles.len());

let solution = model.solve();
log::info!(
"CBC status {:?}, {:?}, obj = {}",
solution.raw().status(),
solution.raw().secondary_status(),
solution.raw().obj_value(),
);

let mut result = ExtractionResult::default();

for (id, var) in &vars {
let active = solution.col(var.active) > 0.0;
if active {
let node_idx = var
.nodes
.iter()
.position(|&n| solution.col(n) > 0.0)
.unwrap();
let node_id = egraph[id].nodes[node_idx].clone();
result.choose(id.clone(), node_id);
}
}
unreachable!()

let cycles = result.find_cycles(egraph, roots);
assert!(cycles.is_empty());
return result;
}
}

Expand Down Expand Up @@ -221,8 +237,12 @@ fn find_cycles(egraph: &EGraph, mut f: impl FnMut(ClassId, usize)) {

if update {
if order.get(&id).is_none() {
order.insert(id.clone(), count);
count += 1;
if egraph[node_id].is_leaf() {
order.insert(id.clone(), 0);
} else {
order.insert(id.clone(), count);
count += 1;
}
}
memo.insert((id.clone(), i), true);
if let Some(mut v) = pending.remove(&id) {
Expand All @@ -235,11 +255,13 @@ fn find_cycles(egraph: &EGraph, mut f: impl FnMut(ClassId, usize)) {

for class in egraph.classes().values() {
let id = &class.id;
for (i, _node) in class.nodes.iter().enumerate() {
for (i, node) in class.nodes.iter().enumerate() {
if let Some(true) = memo.get(&(id.clone(), i)) {
continue;
}
assert!(!egraph[node].is_leaf());
f(id.clone(), i);
}
}
assert!(pending.is_empty());
}

0 comments on commit 5d433b0

Please sign in to comment.