Skip to content

Commit

Permalink
tests/arith now uses define_language!
Browse files Browse the repository at this point in the history
  • Loading branch information
memoryleak47 committed Nov 20, 2024
1 parent a5df7f4 commit 2ddaeac
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 155 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ features = ["explanations"]
slotted-egraphs-derive = "=0.0.1"
fnv = "1.0.7"
tracing = { version = "0.1", features = ["attributes"], optional = true }
symbol_table = { version = "0.3", features = ["global"] }

[dev-dependencies]
symbol_table = { version = "0.3", features = ["global"]}
rand = "0.8.5"

[profile.release]
Expand Down
1 change: 1 addition & 0 deletions src/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ macro_rules! impl_slotless_lang {
}

impl_slotless_lang!(u32);
impl_slotless_lang!(Symbol);

impl<L: Language> Language for Bind<L> {
fn all_slot_occurences_mut(&mut self) -> Vec<&mut Slot> {
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ use std::ops::Deref;
pub(crate) type HashMap<K, V> = fnv::FnvHashMap<K, V>;
pub(crate) type HashSet<T> = fnv::FnvHashSet<T>;

pub use symbol_table::GlobalSymbol as Symbol;

// Whether to enable invariant-checks.
#[cfg(feature = "checks")]
const CHECKS: bool = true;
Expand Down
159 changes: 14 additions & 145 deletions tests/arith/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,150 +15,19 @@ pub use my_cost::*;
mod const_prop;
pub use const_prop::*;

#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum Arith {
// lambda calculus:
Lam(Slot, AppliedId),
App(AppliedId, AppliedId),
Var(Slot),
Let(Slot, AppliedId, AppliedId),

Add(AppliedId, AppliedId),
Mul(AppliedId, AppliedId),

// rest:
Number(u32),
Symbol(Symbol),
}

impl Language for Arith {
fn all_slot_occurences_mut(&mut self) -> Vec<&mut Slot> {
let mut out = Vec::new();
match self {
Arith::Lam(x, b) => {
out.push(x);
out.extend(b.slots_mut());
}
Arith::App(l, r) => {
out.extend(l.slots_mut());
out.extend(r.slots_mut());
}
Arith::Var(x) => {
out.push(x);
}
Arith::Let(x, t, b) => {
out.push(x);
out.extend(t.slots_mut());
out.extend(b.slots_mut());
}

Arith::Add(l, r) => {
out.extend(l.slots_mut());
out.extend(r.slots_mut());
}
Arith::Mul(l, r) => {
out.extend(l.slots_mut());
out.extend(r.slots_mut());
}
Arith::Number(_) => {}
Arith::Symbol(_) => {}
}
out
}

fn public_slot_occurences_mut(&mut self) -> Vec<&mut Slot> {
let mut out = Vec::new();
match self {
Arith::Lam(x, b) => {
out.extend(b.slots_mut().into_iter().filter(|y| *y != x));
}
Arith::App(l, r) => {
out.extend(l.slots_mut());
out.extend(r.slots_mut());
}
Arith::Var(x) => {
out.push(x);
}
Arith::Let(x, t, b) => {
out.extend(b.slots_mut().into_iter().filter(|y| *y != x));
out.extend(t.slots_mut());
}
Arith::Add(l, r) => {
out.extend(l.slots_mut());
out.extend(r.slots_mut());
}
Arith::Mul(l, r) => {
out.extend(l.slots_mut());
out.extend(r.slots_mut());
}
Arith::Number(_) => {}
Arith::Symbol(_) => {}
}
out
}

fn applied_id_occurences_mut(&mut self) -> Vec<&mut AppliedId> {
match self {
Arith::Lam(_, b) => vec![b],
Arith::App(l, r) => vec![l, r],
Arith::Var(_) => vec![],
Arith::Let(_, t, b) => vec![t, b],
Arith::Add(l, r) => vec![l, r],
Arith::Mul(l, r) => vec![l, r],
Arith::Number(_) => vec![],
Arith::Symbol(_) => vec![],
}
}

fn to_op(&self) -> (String, Vec<Child>) {
match self.clone() {
Arith::Lam(s, a) => (String::from("lam"), vec![Child::Slot(s), Child::AppliedId(a)]),
Arith::App(l, r) => (String::from("app"), vec![Child::AppliedId(l), Child::AppliedId(r)]),
Arith::Var(s) => (String::from("var"), vec![Child::Slot(s)]),
Arith::Let(s, t, b) => (String::from("let"), vec![Child::Slot(s), Child::AppliedId(t), Child::AppliedId(b)]),
Arith::Number(n) => (format!("{}", n), vec![]),
Arith::Symbol(s) => (format!("{}", s), vec![]),
Arith::Add(l, r) => (String::from("add"), vec![Child::AppliedId(l), Child::AppliedId(r)]),
Arith::Mul(l, r) => (String::from("mul"), vec![Child::AppliedId(l), Child::AppliedId(r)]),
}
}

fn from_op(op: &str, children: Vec<Child>) -> Option<Self> {
match (op, &*children) {
("lam", [Child::Slot(s), Child::AppliedId(a)]) => Some(Arith::Lam(*s, a.clone())),
("app", [Child::AppliedId(l), Child::AppliedId(r)]) => Some(Arith::App(l.clone(), r.clone())),
("var", [Child::Slot(s)]) => Some(Arith::Var(*s)),
("let", [Child::Slot(s), Child::AppliedId(t), Child::AppliedId(b)]) => Some(Arith::Let(*s, t.clone(), b.clone())),
("add", [Child::AppliedId(l), Child::AppliedId(r)]) => Some(Arith::Add(l.clone(), r.clone())),
("mul", [Child::AppliedId(l), Child::AppliedId(r)]) => Some(Arith::Mul(l.clone(), r.clone())),
(op, []) => {
if let Ok(u) = op.parse::<u32>() {
Some(Arith::Number(u))
} else {
let s: Symbol = op.parse().ok()?;
Some(Arith::Symbol(s))
}
},
_ => None,
}
}

}


use std::fmt::*;

impl Debug for Arith {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
match self {
Arith::Lam(s, b) => write!(f, "(lam {s:?} {b:?})"),
Arith::App(l, r) => write!(f, "(app {l:?} {r:?})"),
Arith::Var(s) => write!(f, "{s:?}"),
Arith::Let(x, t, b) => write!(f, "(let {x:?} {t:?} {b:?})"),
Arith::Add(l, r) => write!(f, "(+ {l:?} {r:?})"),
Arith::Mul(l, r) => write!(f, "(* {l:?} {r:?})"),
Arith::Number(i) => write!(f, "{i}"),
Arith::Symbol(i) => write!(f, "symb{i:?}"),
}
define_language! {
pub enum Arith {
// lambda calculus:
Lam(Bind<AppliedId>) = "lam",
App(AppliedId, AppliedId) = "app",
Var(Slot) = "var",
Let(Bind<AppliedId>, AppliedId) = "let",

Add(AppliedId, AppliedId) = "add",
Mul(AppliedId, AppliedId) = "mul",

// rest:
Number(u32),
Symbol(Symbol),
}
}
14 changes: 7 additions & 7 deletions tests/arith/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub fn rewrite_arith(eg: &mut EGraph<Arith>) {

fn beta() -> Rewrite<Arith> {
let pat = "(app (lam $1 ?b) ?t)";
let outpat = "(let $1 ?t ?b)";
let outpat = "(let $1 ?b ?t)";

Rewrite::new("beta", pat, outpat)
}
Expand All @@ -46,30 +46,30 @@ fn eta_expansion() -> Rewrite<Arith> {
}

fn my_let_unused() -> Rewrite<Arith> {
let pat = "(let $1 ?t ?b)";
let pat = "(let $1 ?b ?t)";
let outpat = "?b";
Rewrite::new_if("my-let-unused", pat, outpat, |subst, _| {
!subst["b"].slots().contains(&Slot::numeric(1))
})
}

fn let_var_same() -> Rewrite<Arith> {
let pat = "(let $1 ?e (var $1))";
let pat = "(let $1 (var $1) ?e)";
let outpat = "?e";
Rewrite::new("let-var-same", pat, outpat)
}

fn let_app() -> Rewrite<Arith> {
let pat = "(let $1 ?e (app ?a ?b))";
let outpat = "(app (let $1 ?e ?a) (let $1 ?e ?b))";
let pat = "(let $1 (app ?a ?b) ?e)";
let outpat = "(app (let $1 ?a ?e) (let $1 ?b ?e))";
Rewrite::new_if("let-app", pat, outpat, |subst, _| {
subst["a"].slots().contains(&Slot::numeric(1)) || subst["b"].slots().contains(&Slot::numeric(1))
})
}

fn let_lam_diff() -> Rewrite<Arith> {
let pat = "(let $1 ?e (lam $2 ?b))";
let outpat = "(lam $2 (let $1 ?e ?b))";
let pat = "(let $1 (lam $2 ?b) ?e)";
let outpat = "(lam $2 (let $1 ?b ?e))";
Rewrite::new_if("let-lam-diff", pat, outpat, |subst, _| {
subst["b"].slots().contains(&Slot::numeric(1))
})
Expand Down
2 changes: 0 additions & 2 deletions tests/entry.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
pub use std::hash::Hash;
pub use slotted_egraphs::*;

pub use symbol_table::GlobalSymbol as Symbol;

pub type HashMap<K, V> = fnv::FnvHashMap<K, V>;
pub type HashSet<T> = fnv::FnvHashSet<T>;

Expand Down

0 comments on commit 2ddaeac

Please sign in to comment.