Skip to content

Commit

Permalink
add SubstMethod
Browse files Browse the repository at this point in the history
  • Loading branch information
memoryleak47 committed Oct 21, 2024
1 parent 41601a8 commit 69240ca
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 27 deletions.
9 changes: 9 additions & 0 deletions src/egraph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,26 @@ pub struct EGraph<L: Language, N: Analysis<L> = ()> {

// TODO remove this if explanations are disabled.
pub(crate) proof_registry: ProofRegistry,

pub(crate) subst_method: Option<Box<dyn SubstMethod<L, N>>>,
}

impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// Creates an empty e-graph.
pub fn new() -> Self {
Self::with_subst_method::<SynExprSubst>()
}

/// Creates an empty e-graph, while specifying the substitution method to use.
pub fn with_subst_method<S: SubstMethod<L, N>>() -> Self {
EGraph {
unionfind: Default::default(),
classes: Default::default(),
hashcons: Default::default(),
syn_hashcons: Default::default(),
pending: Default::default(),
proof_registry: ProofRegistry::default(),
subst_method: Some(S::new_boxed()),
}
}

Expand Down
3 changes: 3 additions & 0 deletions src/rewrite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ pub use ematch::*;
mod pattern;
pub use pattern::*;

mod subst_method;
pub use subst_method::*;

/// An equational rewrite rule.
pub struct Rewrite<L: Language, N: Analysis<L> = ()> {
pub(crate) searcher: Box<dyn Fn(&EGraph<L, N>) -> Box<dyn Any>>,
Expand Down
26 changes: 5 additions & 21 deletions src/rewrite/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,15 @@ pub fn pattern_subst<L: Language, N: Analysis<L>>(eg: &mut EGraph<L, N>, pattern
let x = pattern_subst(eg, &*x, subst);
let t = pattern_subst(eg, &*t, subst);

// ast-size extraction is also an option. but slower without an e-graph analysis.
let term = eg.get_syn_expr(&eg.synify_app_id(b));

do_subst(eg, &term, &x, &t)
// temporary swap-out so that we can access both the e-graph and the subst-method fully.
let mut method = eg.subst_method.take().unwrap();
let out = method.subst(b, x, t, eg);
eg.subst_method = Some(method);
out
},
}
}

// returns re[x := t]
fn do_subst<L: Language, N: Analysis<L>>(eg: &mut EGraph<L, N>, re: &RecExpr<L>, x: &AppliedId, t: &AppliedId) -> AppliedId {
let mut n = re.node.clone();
let mut refs: Vec<&mut AppliedId> = n.applied_id_occurences_mut();
assert_eq!(re.children.len(), refs.len());
for i in 0..refs.len() {
*(refs[i]) = do_subst(eg, &re.children[i], x, t);
}
let app_id = eg.add_syn(n);

if app_id == *x {
return t.clone();
} else {
app_id
}
}

// TODO maybe move into EGraph API?
pub fn lookup_rec_expr<L: Language, N: Analysis<L>>(re: &RecExpr<L>, eg: &EGraph<L, N>) -> Option<AppliedId> {
let mut n = re.node.clone();
Expand Down
53 changes: 53 additions & 0 deletions src/rewrite/subst_method.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use crate::*;

/// Specifies a certain implementation of how substitution `b[x := t]` is implemented internally.
pub trait SubstMethod<L: Language, N: Analysis<L>> {
fn new_boxed() -> Box<dyn SubstMethod<L, N>> where Self: Sized;
fn subst(&mut self, b: AppliedId, x: AppliedId, t: AppliedId, eg: &mut EGraph<L, N>) -> AppliedId;
}

/// A [SubstMethod] that uses the [EGraph::get_syn_expr] of an e-class to do substitution on it.
pub struct SynExprSubst;

impl<L: Language, N: Analysis<L>> SubstMethod<L, N> for SynExprSubst {
fn new_boxed() -> Box<dyn SubstMethod<L, N>> {
Box::new(SynExprSubst)
}

fn subst(&mut self, b: AppliedId, x: AppliedId, t: AppliedId, eg: &mut EGraph<L, N>) -> AppliedId {
let term = eg.get_syn_expr(&eg.synify_app_id(b));
do_term_subst(eg, &term, &x, &t)
}
}

/// A [SubstMethod] that extracts the smallest term (measured by [AstSize]) of an e-class to do substitution on it.
pub struct ExtractionSubst;

impl<L: Language, N: Analysis<L>> SubstMethod<L, N> for ExtractionSubst {
fn new_boxed() -> Box<dyn SubstMethod<L, N>> {
Box::new(ExtractionSubst)
}

fn subst(&mut self, b: AppliedId, x: AppliedId, t: AppliedId, eg: &mut EGraph<L, N>) -> AppliedId {
let term = ast_size_extract::<L, N>(b, eg);
do_term_subst(eg, &term, &x, &t)
}
}

// returns re[x := t]
fn do_term_subst<L: Language, N: Analysis<L>>(eg: &mut EGraph<L, N>, re: &RecExpr<L>, x: &AppliedId, t: &AppliedId) -> AppliedId {
let mut n = re.node.clone();
let mut refs: Vec<&mut AppliedId> = n.applied_id_occurences_mut();
assert_eq!(re.children.len(), refs.len());
for i in 0..refs.len() {
*(refs[i]) = do_term_subst(eg, &re.children[i], x, t);
}
let app_id = eg.add_syn(n);

if app_id == *x {
return t.clone();
} else {
app_id
}
}

10 changes: 5 additions & 5 deletions tests/rise/rewrite.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::*;

pub enum SubstMethod {
pub enum RiseSubstMethod {
Extraction,
SmallStep,
SmallStepUnoptimized,
}

pub fn rise_rules(subst_m: SubstMethod) -> Vec<Rewrite<Rise>> {
pub fn rise_rules(subst_m: RiseSubstMethod) -> Vec<Rewrite<Rise>> {
let mut rewrites = Vec::new();

rewrites.push(eta());
Expand All @@ -23,17 +23,17 @@ pub fn rise_rules(subst_m: SubstMethod) -> Vec<Rewrite<Rise>> {
rewrites.push(separate_dot_hv_simplified());

match subst_m {
SubstMethod::Extraction => {
RiseSubstMethod::Extraction => {
rewrites.push(beta_extr_direct());
},
SubstMethod::SmallStep => {
RiseSubstMethod::SmallStep => {
rewrites.push(beta());
rewrites.push(my_let_unused());
rewrites.push(let_var_same());
rewrites.push(let_app());
rewrites.push(let_lam_diff());
},
SubstMethod::SmallStepUnoptimized => {
RiseSubstMethod::SmallStepUnoptimized => {
rewrites.push(beta());
rewrites.push(let_var_same());
rewrites.push(let_var_diff());
Expand Down
2 changes: 1 addition & 1 deletion tests/rise/tst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ fn assert_reaches(start: &str, goal: &str, steps: usize) {
let start = RecExpr::parse(start).unwrap();
let goal = RecExpr::parse(goal).unwrap();

let rules = rise_rules(SubstMethod::SmallStep);
let rules = rise_rules(RiseSubstMethod::SmallStep);

let mut eg = EGraph::new();
let i1 = eg.add_expr(start.clone());
Expand Down

0 comments on commit 69240ca

Please sign in to comment.