Skip to content

Commit

Permalink
optimize weak_shape computation
Browse files Browse the repository at this point in the history
  • Loading branch information
memoryleak47 committed Nov 25, 2024
1 parent d9a7aef commit 99a4358
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 21 deletions.
24 changes: 24 additions & 0 deletions slotted-egraphs-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub fn define_language(input: TokenStream1) -> TokenStream1 {
let from_syntax_arms2: Vec<TokenStream2> = ie.variants.iter().zip(&str_names).filter_map(|(x, n)| produce_from_syntax2(&name, &n, x)).collect();

let slots_arms: Vec<TokenStream2> = ie.variants.iter().map(|x| produce_slots(&name, x)).collect();
let weak_shape_inplace_arms: Vec<TokenStream2> = ie.variants.iter().map(|x| produce_weak_shape_inplace(&name, x)).collect();

quote! {
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
Expand Down Expand Up @@ -105,6 +106,16 @@ pub fn define_language(input: TokenStream1) -> TokenStream1 {
#(#slots_arms),*
}
}

#[cfg_attr(feature = "trace", tracing::instrument(name = "Lang::weak_shape_inplace", level = "trace", skip_all))]
fn weak_shape_inplace(&mut self) -> slotted_egraphs::SlotMap {
let m = &mut (slotted_egraphs::SlotMap::new(), 0);
match self {
#(#weak_shape_inplace_arms),*
}

m.0.inverse()
}
}
}.to_token_stream().into()
}
Expand Down Expand Up @@ -281,3 +292,16 @@ fn produce_slots(name: &Ident, v: &Variant) -> TokenStream2 {
}
}
}

fn produce_weak_shape_inplace(name: &Ident, v: &Variant) -> TokenStream2 {
let variant_name = &v.ident;
let n = v.fields.len();
let fields: Vec<Ident> = (0..n).map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site())).collect();
quote! {
#name::#variant_name(#(#fields),*) => {
#(
#fields .weak_shape_impl(m);
)*
}
}
}
59 changes: 38 additions & 21 deletions src/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,24 @@ pub trait LanguageChildren: Debug + Clone + Hash + Eq {

fn to_syntax(&self) -> Vec<SyntaxElem>;
fn from_syntax(_: &[SyntaxElem]) -> Option<Self>;

fn weak_shape_impl(&mut self, m: &mut (SlotMap, u32)) { todo!() }
}

fn on_see_slot(s: &mut Slot, m: &mut (SlotMap, u32)) {
if let Some(s2) = m.0.get(*s) {
*s = s2;
} else {
add_slot(s, m);
}
}

fn add_slot(s: &mut Slot, m: &mut (SlotMap, u32)) {
let s2 = Slot::numeric(m.1);
m.1 += 1;
m.0.insert(*s, s2);
*s = s2;
}

impl LanguageChildren for AppliedId {
fn all_slot_occurrences_iter_mut(&mut self) -> impl Iterator<Item=&mut Slot> { self.m.values_mut() }
Expand All @@ -38,6 +54,12 @@ impl LanguageChildren for AppliedId {
_ => None,
}
}

fn weak_shape_impl(&mut self, m: &mut (SlotMap, u32)) {
for x in self.m.values_mut() {
on_see_slot(x, m);
}
}
}

impl LanguageChildren for Slot {
Expand All @@ -56,6 +78,10 @@ impl LanguageChildren for Slot {
_ => None,
}
}

fn weak_shape_impl(&mut self, m: &mut (SlotMap, u32)) {
on_see_slot(self, m);
}
}

/// Implements [LanguageChildren] for payload types that are independent of Slots. For example u32, String etc.
Expand All @@ -79,6 +105,8 @@ macro_rules! bare_language_child {
_ => None,
}
}

fn weak_shape_impl(&mut self, m: &mut (SlotMap, u32)) {}
}
)*
}
Expand Down Expand Up @@ -138,6 +166,13 @@ impl<L: LanguageChildren> LanguageChildren for Bind<L> {
elem,
})
}

fn weak_shape_impl(&mut self, m: &mut (SlotMap, u32)) {
let s = self.slot;
add_slot(&mut self.slot, m);
self.elem.weak_shape_impl(m);
m.0.remove(s);
}
}

// TODO: add LanguageChildren definition for tuples.
Expand Down Expand Up @@ -169,6 +204,7 @@ pub trait Language: Debug + Clone + Hash + Eq {
fn from_syntax(_: &[SyntaxElem]) -> Option<Self>;

fn slots(&self) -> HashSet<Slot>;
fn weak_shape_inplace(&mut self) -> Bijection;

#[track_caller]
#[doc(hidden)]
Expand Down Expand Up @@ -296,27 +332,8 @@ pub trait Language: Debug + Clone + Hash + Eq {
#[cfg_attr(feature = "trace", instrument(level = "trace", skip_all))]
fn weak_shape(&self) -> (Self, Bijection) {
let mut c = self.clone();
let mut m = SlotMap::new();
let mut i = 0;

for x in c.all_slot_occurrences_mut() {
let x_val = *x;
if !m.contains_key(x_val) {
let new_slot = Slot::numeric(i);
i += 1;

m.insert(x_val, new_slot);
}

*x = m[x_val];
}

let m = m.inverse();

let public = c.slots();
let m: SlotMap = m.iter().filter(|(x, _)| public.contains(x)).collect();

(c, m)
let bij = c.weak_shape_inplace();
(c, bij)
}

#[doc(hidden)]
Expand Down

0 comments on commit 99a4358

Please sign in to comment.