From da54faa2c7a41ab0a6bf44d3a52f05796ec104c1 Mon Sep 17 00:00:00 2001 From: memoryleak47 Date: Mon, 28 Oct 2024 17:12:06 +0100 Subject: [PATCH] add cost_rec and make extractor receive &AppliedId --- src/extract/cost.rs | 10 ++++++++++ src/extract/mod.rs | 19 +++++++++++++------ src/rewrite/subst_method.rs | 2 +- tests/array/tst.rs | 4 ++-- tests/lambda/realization.rs | 6 +++--- tests/rise/rewrite.rs | 8 ++++---- tests/rise/tst.rs | 2 +- tests/sdql/rewrite.rs | 2 +- tests/sym/tst.rs | 2 +- 9 files changed, 36 insertions(+), 19 deletions(-) diff --git a/src/extract/cost.rs b/src/extract/cost.rs index 8799ee0..81c632d 100644 --- a/src/extract/cost.rs +++ b/src/extract/cost.rs @@ -8,6 +8,16 @@ use std::marker::PhantomData; pub trait CostFunction { type Cost: Ord + Clone + Debug; fn cost(&self, enode: &L, costs: C) -> Self::Cost where C: Fn(Id) -> Self::Cost; + + fn cost_rec(&self, expr: &RecExpr) -> Self::Cost { + let child_costs: Vec = expr.children.iter().map(|x| self.cost_rec(x)).collect(); + let c = |i: Id| child_costs[i.0].clone(); + let mut node = expr.node.clone(); + for (i, x) in node.applied_id_occurences_mut().iter_mut().enumerate() { + **x = AppliedId::new(Id(i), SlotMap::new()); + } + self.cost(&node, c) + } } /// The 'default' [CostFunction]. It measures the size of the abstract syntax tree of the corresponding term. diff --git a/src/extract/mod.rs b/src/extract/mod.rs index a133929..25012da 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -62,15 +62,15 @@ impl> Extractor { Self { map } } - pub fn extract>(&self, i: AppliedId, eg: &EGraph) -> RecExpr { - let i = eg.find_applied_id(&i); + pub fn extract>(&self, i: &AppliedId, eg: &EGraph) -> RecExpr { + let i = eg.find_applied_id(i); let mut children = Vec::new(); // do I need to refresh some slots here? let l = self.map[&i.id].0.apply_slotmap(&i.m); for child in l.applied_id_occurences() { - let n = self.extract(child, eg); + let n = self.extract(&child, eg); children.push(n); } @@ -81,12 +81,19 @@ impl> Extractor { } } -pub fn ast_size_extract>(i: AppliedId, eg: &EGraph) -> RecExpr { +pub fn ast_size_extract>(i: &AppliedId, eg: &EGraph) -> RecExpr { extract::(i, eg) } // `i` is not allowed to have free variables, hence prefer `Id` over `AppliedId`. -pub fn extract, CF: CostFunction + Default>(i: AppliedId, eg: &EGraph) -> RecExpr { +pub fn extract, CF: CostFunction + Default>(i: &AppliedId, eg: &EGraph) -> RecExpr { let cost_fn = CF::default(); - Extractor::::new(eg, cost_fn).extract(i, eg) + let extractor = Extractor::::new(eg, cost_fn); + let out = extractor.extract(&i, eg); + if CHECKS { + let i = eg.find_id(i.id); + let cost_fn = CF::default(); + assert_eq!(cost_fn.cost_rec(&out), extractor.map[&i].1); + } + out } diff --git a/src/rewrite/subst_method.rs b/src/rewrite/subst_method.rs index 94fd38c..5783177 100644 --- a/src/rewrite/subst_method.rs +++ b/src/rewrite/subst_method.rs @@ -29,7 +29,7 @@ impl> SubstMethod for ExtractionSubst { } fn subst(&mut self, b: AppliedId, x: AppliedId, t: AppliedId, eg: &mut EGraph) -> AppliedId { - let term = ast_size_extract::(b, eg); + let term = ast_size_extract::(&b, eg); do_term_subst(eg, &term, &x, &t) } } diff --git a/tests/array/tst.rs b/tests/array/tst.rs index 11cfc3e..f45b1ac 100644 --- a/tests/array/tst.rs +++ b/tests/array/tst.rs @@ -8,7 +8,7 @@ fn normalize(re: RecExpr) -> RecExpr { for _ in 0..40 { apply_rewrites(&mut eg, &rules); } - extract::<_, _, AstSizeNoLet>(i, &eg) + extract::<_, _, AstSizeNoLet>(&i, &eg) } fn assert_reaches(start: &str, goal: &str, steps: usize, rules: &[&'static str]) { @@ -32,7 +32,7 @@ fn assert_reaches(start: &str, goal: &str, steps: usize, rules: &[&'static str]) } } - dbg!(extract::<_, _, AstSizeNoLet>(i1, &eg)); + dbg!(extract::<_, _, AstSizeNoLet>(&i1, &eg)); dbg!(&goal); assert!(false); } diff --git a/tests/lambda/realization.rs b/tests/lambda/realization.rs index ba1288e..d3ac5c1 100644 --- a/tests/lambda/realization.rs +++ b/tests/lambda/realization.rs @@ -16,7 +16,7 @@ pub fn simplify_to_nf(s: &str) -> String { for _ in 0..NO_ITERS { R::step(&mut eg); - re = extract_ast(&eg, i.clone()); + re = extract_ast(&eg, &i); if lam_step(&re).is_none() { #[cfg(feature = "explanations")] eg.explain_equivalence(orig_re, re.clone()); @@ -41,7 +41,7 @@ pub fn simplify(s: &str) -> String { break; } } - let out = extract_ast(&eg, i.clone()); + let out = extract_ast(&eg, &i); #[cfg(feature = "explanations")] eg.explain_equivalence(re.clone(), out.clone()); @@ -95,7 +95,7 @@ pub fn check_eq(s1: &str, s2: &str) { // Non-Realization functions: -fn extract_ast(eg: &EGraph, i: AppliedId) -> RecExpr { +fn extract_ast(eg: &EGraph, i: &AppliedId) -> RecExpr { extract::<_, _, AstSizeNoLet>(i, eg) } diff --git a/tests/rise/rewrite.rs b/tests/rise/rewrite.rs index 53173ae..d39cc7f 100644 --- a/tests/rise/rewrite.rs +++ b/tests/rise/rewrite.rs @@ -232,8 +232,8 @@ fn beta_extr() -> Rewrite { let mut out: Vec<(Subst, RecExpr)> = Vec::new(); for subst in ematch_all(eg, &a) { - let b = extractor.extract(subst["b"].clone(), eg); - let t = extractor.extract(subst["t"].clone(), eg); + let b = extractor.extract(&subst["b"], eg); + let t = extractor.extract(&subst["t"], eg); let res = re_subst(s, b, &t); out.push((subst, res)); } @@ -266,8 +266,8 @@ fn beta_extr_direct() -> Rewrite { let mut out: Vec<(Subst, RecExpr)> = Vec::new(); for subst in ematch_all(eg, &a) { - let b = extractor.extract(subst["b"].clone(), eg); - let t = extractor.extract(subst["t"].clone(), eg); + let b = extractor.extract(&subst["b"], eg); + let t = extractor.extract(&subst["t"], eg); let res = re_subst(s, b, &t); out.push((subst, res)); } diff --git a/tests/rise/tst.rs b/tests/rise/tst.rs index 9964952..c368beb 100644 --- a/tests/rise/tst.rs +++ b/tests/rise/tst.rs @@ -21,7 +21,7 @@ fn assert_reaches(start: &str, goal: &str, steps: usize) { } } - dbg!(extract::<_, _, AstSizeNoLet>(i1, &eg)); + dbg!(extract::<_, _, AstSizeNoLet>(&i1, &eg)); dbg!(&goal); assert!(false); } diff --git a/tests/sdql/rewrite.rs b/tests/sdql/rewrite.rs index 7d7a5c3..98b8b8b 100644 --- a/tests/sdql/rewrite.rs +++ b/tests/sdql/rewrite.rs @@ -26,7 +26,7 @@ fn t1() { let id = eg.add_syn_expr(re.clone()); apply_rewrites(&mut eg, &rewrites); - let term = extract::<_, _, AstSize>(id.clone(), &eg); + let term = extract::<_, _, AstSize>(&id, &eg); eprintln!("{}", re.to_string()); eprintln!("{}", term.to_string()); } diff --git a/tests/sym/tst.rs b/tests/sym/tst.rs index 99b8652..1129e9a 100644 --- a/tests/sym/tst.rs +++ b/tests/sym/tst.rs @@ -46,7 +46,7 @@ fn assert_reaches(start: &str, goal: &str, steps: usize, extra_rules: &[&'static } } - dbg!(extract::<_, _, AstSize>(i1, &eg)); + dbg!(extract::<_, _, AstSize>(&i1, &eg)); dbg!(&goal); assert!(false); }