Skip to content

Commit

Permalink
add cost_rec and make extractor receive &AppliedId
Browse files Browse the repository at this point in the history
  • Loading branch information
memoryleak47 committed Oct 28, 2024
1 parent f31bb32 commit da54faa
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 19 deletions.
10 changes: 10 additions & 0 deletions src/extract/cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@ use std::marker::PhantomData;
pub trait CostFunction<L: Language> {
type Cost: Ord + Clone + Debug;
fn cost<C>(&self, enode: &L, costs: C) -> Self::Cost where C: Fn(Id) -> Self::Cost;

fn cost_rec(&self, expr: &RecExpr<L>) -> Self::Cost {
let child_costs: Vec<Self::Cost> = expr.children.iter().map(|x| self.cost_rec(x)).collect();
let c = |i: Id| child_costs[i.0].clone();
let mut node = expr.node.clone();
for (i, x) in node.applied_id_occurences_mut().iter_mut().enumerate() {
**x = AppliedId::new(Id(i), SlotMap::new());
}
self.cost(&node, c)
}
}

/// The 'default' [CostFunction]. It measures the size of the abstract syntax tree of the corresponding term.
Expand Down
19 changes: 13 additions & 6 deletions src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@ impl<L: Language, CF: CostFunction<L>> Extractor<L, CF> {
Self { map }
}

pub fn extract<N: Analysis<L>>(&self, i: AppliedId, eg: &EGraph<L, N>) -> RecExpr<L> {
let i = eg.find_applied_id(&i);
pub fn extract<N: Analysis<L>>(&self, i: &AppliedId, eg: &EGraph<L, N>) -> RecExpr<L> {
let i = eg.find_applied_id(i);

let mut children = Vec::new();

// do I need to refresh some slots here?
let l = self.map[&i.id].0.apply_slotmap(&i.m);
for child in l.applied_id_occurences() {
let n = self.extract(child, eg);
let n = self.extract(&child, eg);
children.push(n);
}

Expand All @@ -81,12 +81,19 @@ impl<L: Language, CF: CostFunction<L>> Extractor<L, CF> {
}
}

pub fn ast_size_extract<L: Language, N: Analysis<L>>(i: AppliedId, eg: &EGraph<L, N>) -> RecExpr<L> {
pub fn ast_size_extract<L: Language, N: Analysis<L>>(i: &AppliedId, eg: &EGraph<L, N>) -> RecExpr<L> {
extract::<L, N, AstSize>(i, eg)
}

// `i` is not allowed to have free variables, hence prefer `Id` over `AppliedId`.
pub fn extract<L: Language, N: Analysis<L>, CF: CostFunction<L> + Default>(i: AppliedId, eg: &EGraph<L, N>) -> RecExpr<L> {
pub fn extract<L: Language, N: Analysis<L>, CF: CostFunction<L> + Default>(i: &AppliedId, eg: &EGraph<L, N>) -> RecExpr<L> {
let cost_fn = CF::default();
Extractor::<L, CF>::new(eg, cost_fn).extract(i, eg)
let extractor = Extractor::<L, CF>::new(eg, cost_fn);
let out = extractor.extract(&i, eg);
if CHECKS {
let i = eg.find_id(i.id);
let cost_fn = CF::default();
assert_eq!(cost_fn.cost_rec(&out), extractor.map[&i].1);
}
out
}
2 changes: 1 addition & 1 deletion src/rewrite/subst_method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl<L: Language, N: Analysis<L>> SubstMethod<L, N> for ExtractionSubst {
}

fn subst(&mut self, b: AppliedId, x: AppliedId, t: AppliedId, eg: &mut EGraph<L, N>) -> AppliedId {
let term = ast_size_extract::<L, N>(b, eg);
let term = ast_size_extract::<L, N>(&b, eg);
do_term_subst(eg, &term, &x, &t)
}
}
Expand Down
4 changes: 2 additions & 2 deletions tests/array/tst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ fn normalize(re: RecExpr<Array>) -> RecExpr<Array> {
for _ in 0..40 {
apply_rewrites(&mut eg, &rules);
}
extract::<_, _, AstSizeNoLet>(i, &eg)
extract::<_, _, AstSizeNoLet>(&i, &eg)
}

fn assert_reaches(start: &str, goal: &str, steps: usize, rules: &[&'static str]) {
Expand All @@ -32,7 +32,7 @@ fn assert_reaches(start: &str, goal: &str, steps: usize, rules: &[&'static str])
}
}

dbg!(extract::<_, _, AstSizeNoLet>(i1, &eg));
dbg!(extract::<_, _, AstSizeNoLet>(&i1, &eg));
dbg!(&goal);
assert!(false);
}
Expand Down
6 changes: 3 additions & 3 deletions tests/lambda/realization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub fn simplify_to_nf<R: Realization>(s: &str) -> String {
for _ in 0..NO_ITERS {
R::step(&mut eg);

re = extract_ast(&eg, i.clone());
re = extract_ast(&eg, &i);
if lam_step(&re).is_none() {
#[cfg(feature = "explanations")]
eg.explain_equivalence(orig_re, re.clone());
Expand All @@ -41,7 +41,7 @@ pub fn simplify<R: Realization>(s: &str) -> String {
break;
}
}
let out = extract_ast(&eg, i.clone());
let out = extract_ast(&eg, &i);

#[cfg(feature = "explanations")]
eg.explain_equivalence(re.clone(), out.clone());
Expand Down Expand Up @@ -95,7 +95,7 @@ pub fn check_eq<R: Realization>(s1: &str, s2: &str) {

// Non-Realization functions:

fn extract_ast(eg: &EGraph<Lambda>, i: AppliedId) -> RecExpr<Lambda> {
fn extract_ast(eg: &EGraph<Lambda>, i: &AppliedId) -> RecExpr<Lambda> {
extract::<_, _, AstSizeNoLet>(i, eg)
}

Expand Down
8 changes: 4 additions & 4 deletions tests/rise/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ fn beta_extr() -> Rewrite<Rise> {

let mut out: Vec<(Subst, RecExpr<Rise>)> = Vec::new();
for subst in ematch_all(eg, &a) {
let b = extractor.extract(subst["b"].clone(), eg);
let t = extractor.extract(subst["t"].clone(), eg);
let b = extractor.extract(&subst["b"], eg);
let t = extractor.extract(&subst["t"], eg);
let res = re_subst(s, b, &t);
out.push((subst, res));
}
Expand Down Expand Up @@ -266,8 +266,8 @@ fn beta_extr_direct() -> Rewrite<Rise> {

let mut out: Vec<(Subst, RecExpr<Rise>)> = Vec::new();
for subst in ematch_all(eg, &a) {
let b = extractor.extract(subst["b"].clone(), eg);
let t = extractor.extract(subst["t"].clone(), eg);
let b = extractor.extract(&subst["b"], eg);
let t = extractor.extract(&subst["t"], eg);
let res = re_subst(s, b, &t);
out.push((subst, res));
}
Expand Down
2 changes: 1 addition & 1 deletion tests/rise/tst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ fn assert_reaches(start: &str, goal: &str, steps: usize) {
}
}

dbg!(extract::<_, _, AstSizeNoLet>(i1, &eg));
dbg!(extract::<_, _, AstSizeNoLet>(&i1, &eg));
dbg!(&goal);
assert!(false);
}
Expand Down
2 changes: 1 addition & 1 deletion tests/sdql/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ fn t1() {
let id = eg.add_syn_expr(re.clone());

apply_rewrites(&mut eg, &rewrites);
let term = extract::<_, _, AstSize>(id.clone(), &eg);
let term = extract::<_, _, AstSize>(&id, &eg);
eprintln!("{}", re.to_string());
eprintln!("{}", term.to_string());
}
2 changes: 1 addition & 1 deletion tests/sym/tst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ fn assert_reaches(start: &str, goal: &str, steps: usize, extra_rules: &[&'static
}
}

dbg!(extract::<_, _, AstSize>(i1, &eg));
dbg!(extract::<_, _, AstSize>(&i1, &eg));
dbg!(&goal);
assert!(false);
}
Expand Down

0 comments on commit da54faa

Please sign in to comment.