From 69240ca93fd6e91ecc692178f0659ce1b88ddf56 Mon Sep 17 00:00:00 2001 From: memoryleak47 Date: Mon, 21 Oct 2024 11:43:57 +0200 Subject: [PATCH] add SubstMethod --- src/egraph/mod.rs | 9 +++++++ src/rewrite/mod.rs | 3 +++ src/rewrite/pattern.rs | 26 ++++-------------- src/rewrite/subst_method.rs | 53 +++++++++++++++++++++++++++++++++++++ tests/rise/rewrite.rs | 10 +++---- tests/rise/tst.rs | 2 +- 6 files changed, 76 insertions(+), 27 deletions(-) create mode 100644 src/rewrite/subst_method.rs diff --git a/src/egraph/mod.rs b/src/egraph/mod.rs index 9ccc0c2..146efa1 100644 --- a/src/egraph/mod.rs +++ b/src/egraph/mod.rs @@ -78,10 +78,18 @@ pub struct EGraph = ()> { // TODO remove this if explanations are disabled. pub(crate) proof_registry: ProofRegistry, + + pub(crate) subst_method: Option>>, } impl> EGraph { + /// Creates an empty e-graph. pub fn new() -> Self { + Self::with_subst_method::() + } + + /// Creates an empty e-graph, while specifying the substitution method to use. + pub fn with_subst_method>() -> Self { EGraph { unionfind: Default::default(), classes: Default::default(), @@ -89,6 +97,7 @@ impl> EGraph { syn_hashcons: Default::default(), pending: Default::default(), proof_registry: ProofRegistry::default(), + subst_method: Some(S::new_boxed()), } } diff --git a/src/rewrite/mod.rs b/src/rewrite/mod.rs index 67fbc99..b89a447 100644 --- a/src/rewrite/mod.rs +++ b/src/rewrite/mod.rs @@ -7,6 +7,9 @@ pub use ematch::*; mod pattern; pub use pattern::*; +mod subst_method; +pub use subst_method::*; + /// An equational rewrite rule. pub struct Rewrite = ()> { pub(crate) searcher: Box) -> Box>, diff --git a/src/rewrite/pattern.rs b/src/rewrite/pattern.rs index 01cdce9..9ba8bfb 100644 --- a/src/rewrite/pattern.rs +++ b/src/rewrite/pattern.rs @@ -31,31 +31,15 @@ pub fn pattern_subst>(eg: &mut EGraph, pattern let x = pattern_subst(eg, &*x, subst); let t = pattern_subst(eg, &*t, subst); - // ast-size extraction is also an option. but slower without an e-graph analysis. - let term = eg.get_syn_expr(&eg.synify_app_id(b)); - - do_subst(eg, &term, &x, &t) + // temporary swap-out so that we can access both the e-graph and the subst-method fully. + let mut method = eg.subst_method.take().unwrap(); + let out = method.subst(b, x, t, eg); + eg.subst_method = Some(method); + out }, } } -// returns re[x := t] -fn do_subst>(eg: &mut EGraph, re: &RecExpr, x: &AppliedId, t: &AppliedId) -> AppliedId { - let mut n = re.node.clone(); - let mut refs: Vec<&mut AppliedId> = n.applied_id_occurences_mut(); - assert_eq!(re.children.len(), refs.len()); - for i in 0..refs.len() { - *(refs[i]) = do_subst(eg, &re.children[i], x, t); - } - let app_id = eg.add_syn(n); - - if app_id == *x { - return t.clone(); - } else { - app_id - } -} - // TODO maybe move into EGraph API? pub fn lookup_rec_expr>(re: &RecExpr, eg: &EGraph) -> Option { let mut n = re.node.clone(); diff --git a/src/rewrite/subst_method.rs b/src/rewrite/subst_method.rs new file mode 100644 index 0000000..55888f9 --- /dev/null +++ b/src/rewrite/subst_method.rs @@ -0,0 +1,53 @@ +use crate::*; + +/// Specifies a certain implementation of how substitution `b[x := t]` is implemented internally. +pub trait SubstMethod> { + fn new_boxed() -> Box> where Self: Sized; + fn subst(&mut self, b: AppliedId, x: AppliedId, t: AppliedId, eg: &mut EGraph) -> AppliedId; +} + +/// A [SubstMethod] that uses the [EGraph::get_syn_expr] of an e-class to do substitution on it. +pub struct SynExprSubst; + +impl> SubstMethod for SynExprSubst { + fn new_boxed() -> Box> { + Box::new(SynExprSubst) + } + + fn subst(&mut self, b: AppliedId, x: AppliedId, t: AppliedId, eg: &mut EGraph) -> AppliedId { + let term = eg.get_syn_expr(&eg.synify_app_id(b)); + do_term_subst(eg, &term, &x, &t) + } +} + +/// A [SubstMethod] that extracts the smallest term (measured by [AstSize]) of an e-class to do substitution on it. +pub struct ExtractionSubst; + +impl> SubstMethod for ExtractionSubst { + fn new_boxed() -> Box> { + Box::new(ExtractionSubst) + } + + fn subst(&mut self, b: AppliedId, x: AppliedId, t: AppliedId, eg: &mut EGraph) -> AppliedId { + let term = ast_size_extract::(b, eg); + do_term_subst(eg, &term, &x, &t) + } +} + +// returns re[x := t] +fn do_term_subst>(eg: &mut EGraph, re: &RecExpr, x: &AppliedId, t: &AppliedId) -> AppliedId { + let mut n = re.node.clone(); + let mut refs: Vec<&mut AppliedId> = n.applied_id_occurences_mut(); + assert_eq!(re.children.len(), refs.len()); + for i in 0..refs.len() { + *(refs[i]) = do_term_subst(eg, &re.children[i], x, t); + } + let app_id = eg.add_syn(n); + + if app_id == *x { + return t.clone(); + } else { + app_id + } +} + diff --git a/tests/rise/rewrite.rs b/tests/rise/rewrite.rs index 9051a1e..a0091b6 100644 --- a/tests/rise/rewrite.rs +++ b/tests/rise/rewrite.rs @@ -1,12 +1,12 @@ use crate::*; -pub enum SubstMethod { +pub enum RiseSubstMethod { Extraction, SmallStep, SmallStepUnoptimized, } -pub fn rise_rules(subst_m: SubstMethod) -> Vec> { +pub fn rise_rules(subst_m: RiseSubstMethod) -> Vec> { let mut rewrites = Vec::new(); rewrites.push(eta()); @@ -23,17 +23,17 @@ pub fn rise_rules(subst_m: SubstMethod) -> Vec> { rewrites.push(separate_dot_hv_simplified()); match subst_m { - SubstMethod::Extraction => { + RiseSubstMethod::Extraction => { rewrites.push(beta_extr_direct()); }, - SubstMethod::SmallStep => { + RiseSubstMethod::SmallStep => { rewrites.push(beta()); rewrites.push(my_let_unused()); rewrites.push(let_var_same()); rewrites.push(let_app()); rewrites.push(let_lam_diff()); }, - SubstMethod::SmallStepUnoptimized => { + RiseSubstMethod::SmallStepUnoptimized => { rewrites.push(beta()); rewrites.push(let_var_same()); rewrites.push(let_var_diff()); diff --git a/tests/rise/tst.rs b/tests/rise/tst.rs index afe23fb..9964952 100644 --- a/tests/rise/tst.rs +++ b/tests/rise/tst.rs @@ -4,7 +4,7 @@ fn assert_reaches(start: &str, goal: &str, steps: usize) { let start = RecExpr::parse(start).unwrap(); let goal = RecExpr::parse(goal).unwrap(); - let rules = rise_rules(SubstMethod::SmallStep); + let rules = rise_rules(RiseSubstMethod::SmallStep); let mut eg = EGraph::new(); let i1 = eg.add_expr(start.clone());