diff --git a/Cargo.toml b/Cargo.toml index b81c43f..45b7f5f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,9 +18,9 @@ features = ["explanations"] slotted-egraphs-derive = "=0.0.1" fnv = "1.0.7" tracing = { version = "0.1", features = ["attributes"], optional = true } +symbol_table = { version = "0.3", features = ["global"] } [dev-dependencies] -symbol_table = { version = "0.3", features = ["global"]} rand = "0.8.5" [profile.release] diff --git a/src/lang.rs b/src/lang.rs index a2a503b..f526d57 100644 --- a/src/lang.rs +++ b/src/lang.rs @@ -45,6 +45,7 @@ macro_rules! impl_slotless_lang { } impl_slotless_lang!(u32); +impl_slotless_lang!(Symbol); impl Language for Bind { fn all_slot_occurences_mut(&mut self) -> Vec<&mut Slot> { diff --git a/src/lib.rs b/src/lib.rs index 3b79c80..aa1978c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,8 @@ use std::ops::Deref; pub(crate) type HashMap = fnv::FnvHashMap; pub(crate) type HashSet = fnv::FnvHashSet; +pub use symbol_table::GlobalSymbol as Symbol; + // Whether to enable invariant-checks. #[cfg(feature = "checks")] const CHECKS: bool = true; diff --git a/tests/arith/mod.rs b/tests/arith/mod.rs index b8a96f0..ab3c608 100644 --- a/tests/arith/mod.rs +++ b/tests/arith/mod.rs @@ -15,150 +15,19 @@ pub use my_cost::*; mod const_prop; pub use const_prop::*; -#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] -pub enum Arith { - // lambda calculus: - Lam(Slot, AppliedId), - App(AppliedId, AppliedId), - Var(Slot), - Let(Slot, AppliedId, AppliedId), - - Add(AppliedId, AppliedId), - Mul(AppliedId, AppliedId), - - // rest: - Number(u32), - Symbol(Symbol), -} - -impl Language for Arith { - fn all_slot_occurences_mut(&mut self) -> Vec<&mut Slot> { - let mut out = Vec::new(); - match self { - Arith::Lam(x, b) => { - out.push(x); - out.extend(b.slots_mut()); - } - Arith::App(l, r) => { - out.extend(l.slots_mut()); - out.extend(r.slots_mut()); - } - Arith::Var(x) => { - out.push(x); - } - Arith::Let(x, t, b) => { - out.push(x); - out.extend(t.slots_mut()); - out.extend(b.slots_mut()); - } - - Arith::Add(l, r) => { - out.extend(l.slots_mut()); - out.extend(r.slots_mut()); - } - Arith::Mul(l, r) => { - out.extend(l.slots_mut()); - out.extend(r.slots_mut()); - } - Arith::Number(_) => {} - Arith::Symbol(_) => {} - } - out - } - - fn public_slot_occurences_mut(&mut self) -> Vec<&mut Slot> { - let mut out = Vec::new(); - match self { - Arith::Lam(x, b) => { - out.extend(b.slots_mut().into_iter().filter(|y| *y != x)); - } - Arith::App(l, r) => { - out.extend(l.slots_mut()); - out.extend(r.slots_mut()); - } - Arith::Var(x) => { - out.push(x); - } - Arith::Let(x, t, b) => { - out.extend(b.slots_mut().into_iter().filter(|y| *y != x)); - out.extend(t.slots_mut()); - } - Arith::Add(l, r) => { - out.extend(l.slots_mut()); - out.extend(r.slots_mut()); - } - Arith::Mul(l, r) => { - out.extend(l.slots_mut()); - out.extend(r.slots_mut()); - } - Arith::Number(_) => {} - Arith::Symbol(_) => {} - } - out - } - - fn applied_id_occurences_mut(&mut self) -> Vec<&mut AppliedId> { - match self { - Arith::Lam(_, b) => vec![b], - Arith::App(l, r) => vec![l, r], - Arith::Var(_) => vec![], - Arith::Let(_, t, b) => vec![t, b], - Arith::Add(l, r) => vec![l, r], - Arith::Mul(l, r) => vec![l, r], - Arith::Number(_) => vec![], - Arith::Symbol(_) => vec![], - } - } - - fn to_op(&self) -> (String, Vec) { - match self.clone() { - Arith::Lam(s, a) => (String::from("lam"), vec![Child::Slot(s), Child::AppliedId(a)]), - Arith::App(l, r) => (String::from("app"), vec![Child::AppliedId(l), Child::AppliedId(r)]), - Arith::Var(s) => (String::from("var"), vec![Child::Slot(s)]), - Arith::Let(s, t, b) => (String::from("let"), vec![Child::Slot(s), Child::AppliedId(t), Child::AppliedId(b)]), - Arith::Number(n) => (format!("{}", n), vec![]), - Arith::Symbol(s) => (format!("{}", s), vec![]), - Arith::Add(l, r) => (String::from("add"), vec![Child::AppliedId(l), Child::AppliedId(r)]), - Arith::Mul(l, r) => (String::from("mul"), vec![Child::AppliedId(l), Child::AppliedId(r)]), - } - } - - fn from_op(op: &str, children: Vec) -> Option { - match (op, &*children) { - ("lam", [Child::Slot(s), Child::AppliedId(a)]) => Some(Arith::Lam(*s, a.clone())), - ("app", [Child::AppliedId(l), Child::AppliedId(r)]) => Some(Arith::App(l.clone(), r.clone())), - ("var", [Child::Slot(s)]) => Some(Arith::Var(*s)), - ("let", [Child::Slot(s), Child::AppliedId(t), Child::AppliedId(b)]) => Some(Arith::Let(*s, t.clone(), b.clone())), - ("add", [Child::AppliedId(l), Child::AppliedId(r)]) => Some(Arith::Add(l.clone(), r.clone())), - ("mul", [Child::AppliedId(l), Child::AppliedId(r)]) => Some(Arith::Mul(l.clone(), r.clone())), - (op, []) => { - if let Ok(u) = op.parse::() { - Some(Arith::Number(u)) - } else { - let s: Symbol = op.parse().ok()?; - Some(Arith::Symbol(s)) - } - }, - _ => None, - } - } - -} - - -use std::fmt::*; - -impl Debug for Arith { - fn fmt(&self, f: &mut Formatter<'_>) -> Result { - match self { - Arith::Lam(s, b) => write!(f, "(lam {s:?} {b:?})"), - Arith::App(l, r) => write!(f, "(app {l:?} {r:?})"), - Arith::Var(s) => write!(f, "{s:?}"), - Arith::Let(x, t, b) => write!(f, "(let {x:?} {t:?} {b:?})"), - Arith::Add(l, r) => write!(f, "(+ {l:?} {r:?})"), - Arith::Mul(l, r) => write!(f, "(* {l:?} {r:?})"), - Arith::Number(i) => write!(f, "{i}"), - Arith::Symbol(i) => write!(f, "symb{i:?}"), - } +define_language! { + pub enum Arith { + // lambda calculus: + Lam(Bind) = "lam", + App(AppliedId, AppliedId) = "app", + Var(Slot) = "var", + Let(Bind, AppliedId) = "let", + + Add(AppliedId, AppliedId) = "add", + Mul(AppliedId, AppliedId) = "mul", + + // rest: + Number(u32), + Symbol(Symbol), } } diff --git a/tests/arith/rewrite.rs b/tests/arith/rewrite.rs index 7c29bd3..8b68faa 100644 --- a/tests/arith/rewrite.rs +++ b/tests/arith/rewrite.rs @@ -25,7 +25,7 @@ pub fn rewrite_arith(eg: &mut EGraph) { fn beta() -> Rewrite { let pat = "(app (lam $1 ?b) ?t)"; - let outpat = "(let $1 ?t ?b)"; + let outpat = "(let $1 ?b ?t)"; Rewrite::new("beta", pat, outpat) } @@ -46,7 +46,7 @@ fn eta_expansion() -> Rewrite { } fn my_let_unused() -> Rewrite { - let pat = "(let $1 ?t ?b)"; + let pat = "(let $1 ?b ?t)"; let outpat = "?b"; Rewrite::new_if("my-let-unused", pat, outpat, |subst, _| { !subst["b"].slots().contains(&Slot::numeric(1)) @@ -54,22 +54,22 @@ fn my_let_unused() -> Rewrite { } fn let_var_same() -> Rewrite { - let pat = "(let $1 ?e (var $1))"; + let pat = "(let $1 (var $1) ?e)"; let outpat = "?e"; Rewrite::new("let-var-same", pat, outpat) } fn let_app() -> Rewrite { - let pat = "(let $1 ?e (app ?a ?b))"; - let outpat = "(app (let $1 ?e ?a) (let $1 ?e ?b))"; + let pat = "(let $1 (app ?a ?b) ?e)"; + let outpat = "(app (let $1 ?a ?e) (let $1 ?b ?e))"; Rewrite::new_if("let-app", pat, outpat, |subst, _| { subst["a"].slots().contains(&Slot::numeric(1)) || subst["b"].slots().contains(&Slot::numeric(1)) }) } fn let_lam_diff() -> Rewrite { - let pat = "(let $1 ?e (lam $2 ?b))"; - let outpat = "(lam $2 (let $1 ?e ?b))"; + let pat = "(let $1 (lam $2 ?b) ?e)"; + let outpat = "(lam $2 (let $1 ?b ?e))"; Rewrite::new_if("let-lam-diff", pat, outpat, |subst, _| { subst["b"].slots().contains(&Slot::numeric(1)) }) diff --git a/tests/entry.rs b/tests/entry.rs index dc52938..d45e4f1 100644 --- a/tests/entry.rs +++ b/tests/entry.rs @@ -1,8 +1,6 @@ pub use std::hash::Hash; pub use slotted_egraphs::*; -pub use symbol_table::GlobalSymbol as Symbol; - pub type HashMap = fnv::FnvHashMap; pub type HashSet = fnv::FnvHashSet;