Simple set of addition rewrites already explodes #334

Open
KarelPeeters opened this issue Sep 1, 2024 · 7 comments

Comments

@KarelPeeters

KarelPeeters commented Sep 1, 2024

Hey! Thanks for doing the equivalence graph research and building this crate as an implementation.

I'm investigating how easy it is to use egg as part of the integer range type system for an experimental language I'm working on. Right now I'm trying to see how fast saturation is reached for simple arithmetic expressions that only involve addition and negation. This is the full test code I'm running:

use egg::{rewrite, Rewrite, Runner, SymbolLang};

fn main() {
    let mut rules: Vec<Rewrite<SymbolLang, ()>> = vec![];
    rules.push(rewrite!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"));
    rules.extend(rewrite!("assoc-add"; "(+ (+ ?x ?y) ?z)" <=> "(+ ?x (+ ?y ?z))"));
    rules.push(rewrite!("cancel-add-neg"; "(+ ?x (- ?x))" => "0"));
    rules.push(rewrite!("add-0"; "(+ ?x 0)" => "?x"));

    let expr = "(+ a (+ b (- a)))";
    let runner = Runner::default()
        .with_explanations_enabled()
        .with_expr(&expr.parse().unwrap())
        .with_node_limit(64)
        .run(&rules);

    runner.print_report();
    let mut egraph = runner.egraph;
    
    println!("Generated classes:");
    for class in egraph.classes() {
        println!("  {}", egraph.id_to_expr(class.id));
    }

    println!("Explanation for final generated class:");
    let newest_id = egraph.classes().map(|c| c.id).max().unwrap();
    let newest_expr = egraph.id_to_expr(newest_id);
    println!("{}", egraph.explain_existance(&newest_expr).get_flat_string());
}

This already explodes: without a node or time limit, the graph keeps growing forever (as far as I can tell). The set of rules I included here is minimal: removing any single rule fixes the "exploding graph" problem, at the cost of severely reducing expressiveness. This is the shortened output of running this program with all 4 rules:

Runner report
=============
  Stop reason: NodeLimit(192)
  Iterations: 6
  Egraph size: 179 nodes, 58 classes, 192 memo
  Rebuilds: 0
  Total time: 0.033573627
    Search:  (0.05) 0.001724439
    Apply:   (0.25) 0.008356102
    Rebuild: (0.70) 0.023483781000000002

Generated classes:
  a
  b
  (- a)
  (+ b (- a))
  (+ a b)
  (+ a (- a))
  (+ (+ a (- a)) a)
  (+ (+ a b) a)
  (+ a a)
  (+ (+ a (- a)) (+ a (- a)))
  (+ (- a) (+ a (- a)))
  (+ (- a) (- a))
  (+ (- a) (+ b (- a)))
  (+ (+ (+ a (- a)) a) a)
  (+ (+ (+ a b) a) a)
  (+ (+ a a) a)
  (+ (+ (+ a (- a)) (+ a (- a))) a)
  (+ (+ (- a) (+ a (- a))) a)
  (+ (+ a (- a)) (+ (+ a (- a)) a))
  (+ (+ (+ a (- a)) (+ a (- a))) (+ (+ a (- a)) a))
  (+ a (+ (+ a (- a)) a))
  (+ (+ (+ a (- a)) a) (+ (+ a (- a)) a))
  (+ (+ (- a) (+ a (- a))) (+ (+ a (- a)) a))
  (+ (+ (- a) (+ a (- a))) (+ a a))
  (+ (+ a (- a)) (+ a a))
  (+ (+ (- a) (- a)) (+ a a))
  (+ (+ a b) (+ a a))
  (+ a (+ a a))
  (+ (+ (+ a (- a)) a) (+ a (- a)))
  (+ (+ a a) (+ a (- a)))
  (+ (+ (+ a (- a)) (+ a (- a))) (+ a (- a)))
  (+ (+ (- a) (+ a (- a))) (+ a (- a)))
  (+ (+ a (- a)) (+ (+ a (- a)) (+ a (- a))))
  (+ (+ (+ a (- a)) (+ a (- a))) (+ (+ a (- a)) (+ a (- a))))
  (+ a (+ (+ a (- a)) (+ a (- a))))
  (+ (+ (+ a (- a)) a) (+ (+ a (- a)) (+ a (- a))))
  (+ (- a) (+ (+ a (- a)) (+ a (- a))))
  (+ (+ (- a) (+ a (- a))) (+ (+ a (- a)) (+ a (- a))))
  (+ (+ (- a) (- a)) a)
  (+ (+ (- a) (- a)) (+ (+ a (- a)) a))
  (+ (+ (+ a (- a)) a) (- a))
  (+ (+ a a) (- a))
  (+ (+ (+ a (- a)) a) (+ (- a) (+ a (- a))))
  [...]
Explanation for final generated class:
(+ a (+ b (- a)))
(+ a (+ (Rewrite<= add-0 (+ b 0)) (- a)))
(+ a (+ (+ b (Rewrite<= cancel-add-neg (+ a (- a)))) (- a)))
(+ a (+ (+ b (Rewrite=> commute-add (+ (- a) a))) (- a)))
(+ a (+ (Rewrite<= assoc-add (+ (+ b (- a)) a)) (- a)))
(+ a (+ (Rewrite<= commute-add (+ a (+ b (- a)))) (- a)))
(+ a (+ (Rewrite=> assoc-add-rev (+ (+ a b) (- a))) (- a)))
(Rewrite=> assoc-add (+ (+ a b) (+ (- a) (- a))))
(+ (Rewrite=> commute-add (+ b a)) (+ (- a) (- a)))
(+ (+ (Rewrite<= add-0 (+ b 0)) a) (+ (- a) (- a)))
(+ (+ (+ b (Rewrite<= cancel-add-neg (+ a (- a)))) a) (+ (- a) (- a)))
(+ (+ (+ b (Rewrite=> commute-add (+ (- a) a))) a) (+ (- a) (- a)))
(+ (+ (Rewrite<= assoc-add (+ (+ b (- a)) a)) a) (+ (- a) (- a)))
(+ (+ (Rewrite<= commute-add (+ a (+ b (- a)))) a) (+ (- a) (- a)))
(+ (+ (Rewrite=> assoc-add-rev (+ (+ a b) (- a))) a) (+ (- a) (- a)))
(+ (+ (Rewrite=> commute-add (+ (- a) (+ a b))) a) (+ (- a) (- a)))
(+ (Rewrite=> assoc-add (+ (- a) (+ (+ a b) a))) (+ (- a) (- a)))
(+ (Rewrite=> commute-add (+ (+ (+ a b) a) (- a))) (+ (- a) (- a)))
(Rewrite=> assoc-add (+ (+ (+ a b) a) (+ (- a) (+ (- a) (- a)))))

To me it looks like it keeps generating new classes that are equivalent to classes it already knows about, faster than it can prove that they are indeed equivalent. For example, the final class that is printed, (+ (+ (+ a b) a) (+ (- a) (+ (- a) (- a)))), is really just (+ b (- a)) again, but the two have not yet been proven to be the same.
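The claim that the final class is really (+ b (- a)) can be double-checked by hand. If +, unary -, 0 and symbols are read as integer arithmetic, a term is fully described by the integer coefficient of each symbol, so two terms denote the same value exactly when their coefficient maps agree. The Rust sketch below computes that map for the s-expression syntax used in this issue. `coefficients` and `walk` are hypothetical helpers written only for illustration; this is a hand check of the arithmetic, not how egg (which only uses the rewrite rules) establishes equivalence.

```rust
use std::collections::BTreeMap;

/// Hypothetical helper: reads an s-expression built from `+`, unary `-`,
/// `0` and symbols, and returns the integer coefficient of each symbol.
/// Symbols whose coefficient cancels to zero are dropped.
fn coefficients(src: &str) -> BTreeMap<String, i64> {
    let spaced = src.replace('(', " ( ").replace(')', " ) ");
    let tokens: Vec<&str> = spaced.split_whitespace().collect();
    let mut pos = 0;
    let mut coeffs = BTreeMap::new();
    walk(&tokens, &mut pos, 1, &mut coeffs);
    assert_eq!(pos, tokens.len(), "trailing tokens");
    coeffs.retain(|_, c| *c != 0);
    coeffs
}

/// Recursive descent over one term; `sign` is +1 or -1 depending on how many
/// enclosing `(- ...)` nodes there are. Panics on malformed input.
fn walk(tokens: &[&str], pos: &mut usize, sign: i64, out: &mut BTreeMap<String, i64>) {
    let tok = tokens[*pos];
    *pos += 1;
    match tok {
        "(" => {
            let op = tokens[*pos];
            *pos += 1;
            match op {
                // (+ x y): both operands keep the current sign
                "+" => {
                    walk(tokens, pos, sign, out);
                    walk(tokens, pos, sign, out);
                }
                // (- x): unary negation flips the sign of its operand
                "-" => walk(tokens, pos, -sign, out),
                other => panic!("unsupported operator {other}"),
            }
            assert_eq!(tokens[*pos], ")", "expected ')'");
            *pos += 1;
        }
        "0" => {}
        ")" => panic!("unexpected ')'"),
        sym => *out.entry(sym.to_string()).or_insert(0) += sign,
    }
}
```

Both (+ (+ (+ a b) a) (+ (- a) (+ (- a) (- a)))) and (+ b (- a)) map to {a: -1, b: 1}, so they are equal as integer expressions even though the e-graph had not merged them yet.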

I'm surprised by steps like (Rewrite<= add-0 (+ b 0)) being included in the existence explanation. I defined the one-directional rewrite rewrite!("add-0"; "(+ ?x 0)" => "?x"), so why is it being used in the opposite direction too? Is it expected behavior that simple sets of rules already blow up the graph forever?

@mwillsey
Member

mwillsey commented Sep 1, 2024

Yes, associativity (especially together with commutativity) is known to blow up the e-graph. There are various strategies to control this, but they basically amount to artificial limitations on the equality saturation process.

The problem is that equality saturation not only reasons about the (sub)terms that you give it, but it creates many new terms along the way (terms that are not equivalent to anything in the input) and reasons about them too! Associativity is a good example: once you rewrite a + (b + c) to (a + b) + c, you've made a new term a + b that is not equivalent to anything in the input!
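To get a feel for the size of this blow-up, here is a small self-contained Rust sketch (not egg; `saturate_ac` and its representation are hypothetical) that applies only commute-add and assoc-add (in both directions) to the right-nested sum of n distinct variables until nothing changes. Because only these two rules are used and the variables are assumed distinct, two sums are equal exactly when they contain the same variables, so each e-class can be identified by a bitmask of its variables and each e-node by an ordered pair of child classes.

```rust
use std::collections::{BTreeMap, BTreeSet};

/// An e-class is identified by the set of variables it sums, as a bitmask.
type Class = u32;
/// A binary `+` e-node: the ordered pair (left child class, right child class).
type Node = (Class, Class);

/// Saturates `x0 + (x1 + (... + x{n-1}))` under commute-add and assoc-add
/// (both directions) and returns (number of e-classes, number of e-nodes).
/// Assumes 1 <= n <= 31 and that all variables are distinct.
fn saturate_ac(n: u32) -> (usize, usize) {
    assert!((1..=31).contains(&n));
    // Map from e-class to its binary e-nodes; leaf classes have none.
    let mut egraph: BTreeMap<Class, BTreeSet<Node>> = BTreeMap::new();
    for i in 0..n {
        egraph.insert(1 << i, BTreeSet::new());
    }
    // Insert the input term, building it from the right.
    let mut acc: Class = 1 << (n - 1);
    for i in (0..n - 1).rev() {
        let leaf: Class = 1 << i;
        egraph.entry(leaf | acc).or_default().insert((leaf, acc));
        acc |= leaf;
    }
    // Apply every rule to every match, then add all results, until a fixpoint.
    loop {
        let mut found: Vec<(Class, Node)> = Vec::new();
        for (&class, nodes) in &egraph {
            for &(l, r) in nodes {
                // commute-add: (+ ?x ?y) => (+ ?y ?x)
                found.push((class, (r, l)));
                // assoc-add: (+ (+ ?x ?y) ?z) => (+ ?x (+ ?y ?z)), with ?z = r
                for &(x, y) in &egraph[&l] {
                    found.push((y | r, (y, r)));
                    found.push((class, (x, y | r)));
                }
                // assoc-add-rev: (+ ?x (+ ?y ?z)) => (+ (+ ?x ?y) ?z), with ?x = l
                for &(y, z) in &egraph[&r] {
                    found.push((l | y, (l, y)));
                    found.push((class, (l | y, z)));
                }
            }
        }
        let mut changed = false;
        for (class, node) in found {
            changed |= egraph.entry(class).or_default().insert(node);
        }
        if !changed {
            break;
        }
    }
    // Count the n leaf e-nodes as well as the binary ones.
    let binary: usize = egraph.values().map(|s| s.len()).sum();
    (egraph.len(), binary + n as usize)
}
```

For n = 3 this gives 7 e-classes and 15 e-nodes. In general the fixpoint contains every non-empty subset of the variables as an e-class (2^n - 1 classes) and every ordered split of each subset as an e-node (3^n - 2^(n+1) + 1 binary e-nodes), so the e-graph grows exponentially in n even though no rule makes an individual term larger.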

This problem is particularly bad with associativity, because you're discovering a million new ways to rewrite 0. So one approach people take is to prune 0's e-class:

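use egg::{rewrite, Analysis, DidMerge, Id, Language, Rewrite, Runner, SymbolLang};

#[test]
fn foo() {
    #[derive(Default)]
    struct MyAnalysis;

    type EGraph = egg::EGraph<SymbolLang, MyAnalysis>;

    impl Analysis<SymbolLang> for MyAnalysis {
        // No per-class data; the analysis exists only for its `modify` hook.
        type Data = ();

        fn make(_egraph: &mut EGraph, _enode: &SymbolLang) -> Self::Data {}

        fn merge(&mut self, _a: &mut Self::Data, _b: Self::Data) -> DidMerge {
            DidMerge(false, false)
        }

        // If an e-class contains a leaf (a constant or variable), drop all of
        // its non-leaf e-nodes so they are never rewritten again.
        fn modify(egraph: &mut EGraph, id: Id) {
            let nodes = &mut egraph[id].nodes;
            if nodes.iter().any(|n| n.is_leaf()) {
                nodes.retain(|n| n.is_leaf());
            }
        }
    }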

    let mut rules: Vec<Rewrite<SymbolLang, MyAnalysis>> = vec![];
    rules.push(rewrite!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"));
    rules.extend(rewrite!("assoc-add"; "(+ (+ ?x ?y) ?z)" <=> "(+ ?x (+ ?y ?z))"));
    rules.push(rewrite!("cancel-add-neg"; "(+ ?x (- ?x))" => "0"));
    rules.push(rewrite!("add-0"; "(+ ?x 0)" => "?x"));

    let expr = "(+ a (+ b (- a)))";
    let runner = Runner::default()
        .with_explanations_enabled()
        .with_expr(&expr.parse().unwrap())
        .with_node_limit(64)
        .run(&rules);

    runner.print_report();
    let mut egraph = runner.egraph;

    println!("Generated classes:");
    for class in egraph.classes() {
        println!("  {}", egraph.id_to_expr(class.id));
    }

    println!("Explanation for final generated class:");
    let newest_id = egraph.classes().map(|c| c.id).max().unwrap();
    let newest_expr = egraph.id_to_expr(newest_id);
    println!(
        "{}",
        egraph.explain_existance(&newest_expr).get_flat_string()
    );
}

All I have done here is add a trivial e-class analysis that does nothing except use the modify hook to prune any e-class that contains leaf e-nodes: all of its non-leaf e-nodes are removed. This works on the assumption that once you prove something equivalent to a constant/variable, you don't need to rewrite it any further.

@KarelPeeters
Author

KarelPeeters commented Sep 1, 2024

Thanks for the quick response!

> The problem is that equality saturation not only reasons about the (sub)terms that you give it, but it creates many new terms along the way (terms that are not equivalent to anything in the input) and reasons about them too! Associativity is a good example: once you rewrite a + (b + c) to (a + b) + c, you've made a new term a + b that is not equivalent to anything in the input!

What still surprises me is that every rewrite rule I included either reduces or keeps the number of operands of an expression, so how is it possible that the number of nodes keeps growing indefinitely?

> This problem is particularly bad with associativity, because you're discovering a million new ways to rewrite 0. So one approach people take is to prune 0's e-class: [...]

Thanks, that does work, and the underlying idea makes sense: if something is already proven to be a constant or a variable it's no longer possible to simplify it further (or in my case to put it into a form where it's easier to derive range information).

It stops working again when I slightly expand the ruleset, in particular with distributivity between addition and multiplication. This rewrite rule definitely needs to be able to temporarily increase the size of expressions to enable certain optimizations, though, so maybe it's too much to ever expect saturation here.

// keep previous rules, add new rules below
rules.push(rewrite!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"));
rules.extend(rewrite!("assoc-mul"; "(* (* ?x ?y) ?z)" <=> "(* ?x (* ?y ?z))"));
rules.extend(rewrite!("distr-add-mul"; "(+ (* ?x ?z) (* ?y ?z))" <=> "(* (+ ?x ?y) ?z)"));
rules.push(rewrite!("mul-0"; "(* ?x 0)" => "0"));
rules.push(rewrite!("mul-1"; "(* ?x 1)" => "?x"));

// new test expression
let expr = "(+ (* a (+ a b)) (- (* a b)))";

One pruning idea I'm considering is to prevent the AST size from growing beyond, e.g., twice the size of the largest original expression. That should at least make the space finite, and hopefully does not prevent too many useful simplifications.
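For SymbolLang-style s-expressions, the AST size of a term is just the number of operator and atom tokens (every operator and every leaf is one node, parentheses are not), so the proposed limit can be computed up front. A minimal Rust sketch; `ast_size` and `size_limit` are hypothetical helpers, while egg's own `AstSize` cost function computes the same count from a parsed `RecExpr`.

```rust
/// Counts the nodes of an s-expression such as "(+ a (- b))": every operator
/// and every atom is one node; parentheses are not counted.
fn ast_size(sexp: &str) -> usize {
    sexp.replace('(', " ")
        .replace(')', " ")
        .split_whitespace()
        .count()
}

/// Hypothetical limit from the idea above: twice the size of the largest input.
fn size_limit(inputs: &[&str]) -> usize {
    inputs.iter().map(|s| ast_size(s)).max().unwrap_or(0) * 2
}
```

For the test expression (+ (* a (+ a b)) (- (* a b))) this gives a size of 10, and therefore a limit of 20.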

In general, is saturation expected to be unreachable? Or will there always be some neat set of tricks to prune the space to prevent endless superfluous expansion?

@mwillsey
Member

mwillsey commented Sep 1, 2024

Saturation will be unreachable in many use cases that involve algebraic rules like these. Most use cases are just about getting something "useful" done in the time you're willing to spend.

@KarelPeeters
Author

KarelPeeters commented Sep 3, 2024

I see, thanks for the info!

I've tried my approach of limiting the AST size instead of using some time or node limit, and it seems to be working well. I'm sharing it here in case anyone ever wants to do something similar.

The approach is:

  • Define a custom analysis that computes the AST size of the smallest possible representation of each e-class. This is faster than repeatedly calling the built-in AstSize, since it works incrementally (I think).
  • Add a custom condition to every rewrite rule that only accepts the rewrite if the estimated AST size of the new expression does not exceed the limit.

Some example code implementing this is available here, or also in the spoiler below.

Spoiler
use egg::{rewrite, Analysis, AstSize, Condition, CostFunction, DidMerge, Extractor, Id, Rewrite, Runner, Subst, SymbolLang, Var};
use itertools::Itertools;
use std::cmp::min;
use std::str::FromStr;
use std::time::Duration;

fn main() {
    let mut rules: Vec<Rewrite<SymbolLang, _>> = vec![];

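    let expr = "(+ (* a (+ a b)) (- (* a b)))";
    let expr: egg::RecExpr<SymbolLang> = expr.parse().unwrap();

    // Accept a rewrite only if its estimated result is at most twice the input size.
    let max_ast_size = AstSize.cost_rec(&expr) * 2;
    // Shadow the free function with a closure that fixes the `max` argument.
    let build_check = |base, vars| build_check(max_ast_size, base, vars);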

    rules.push(rewrite!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)" if build_check(1, "?x ?y")));
    rules.extend(rewrite!("assoc-add"; "(+ (+ ?x ?y) ?z)" <=> "(+ ?x (+ ?y ?z))" if build_check(2, "?x ?y ?z")));
    rules.push(rewrite!("cancel-add-neg"; "(+ ?x (- ?x))" => "0" if build_check(0, "")));
    rules.push(rewrite!("add-0"; "(+ ?x 0)" => "?x" if build_check(0, "?x")));

    rules.push(rewrite!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)" if build_check(1, "?x ?y")));
    rules.extend(rewrite!("assoc-mul"; "(* (* ?x ?y) ?z)" <=> "(* ?x (* ?y ?z))" if build_check(2, "?x ?y ?z")));
    rules.push(rewrite!("distr-add-mul-fwd"; "(+ (* ?x ?z) (* ?y ?z))" => "(* (+ ?x ?y) ?z)" if build_check(2, "?x ?y ?z")));
    rules.push(rewrite!("distr-add-mul-back"; "(* (+ ?x ?y) ?z)" => "(+ (* ?x ?z) (* ?y ?z))" if build_check(3, "?x ?y ?z ?z")));
    rules.push(rewrite!("mul-0"; "(* ?x 0)" => "0" if build_check(0, "")));
    rules.push(rewrite!("mul-1"; "(* ?x 1)" => "?x" if build_check(0, "?x")));

    let runner: Runner<SymbolLang, AstSizeAnalysis> = Runner::new(AstSizeAnalysis)
        .with_explanations_enabled()
        .with_expr(&expr)
        .with_iter_limit(usize::MAX)
        .with_node_limit(usize::MAX)
        .with_time_limit(Duration::MAX)
        .run(&rules);

    runner.print_report();
    println!();
    let egraph = runner.egraph;
    let extractor = Extractor::new(&egraph, AstSize);

    println!("Generated classes:");
    for class in egraph.classes().sorted_by_key(|c| c.id) {
        let (_, expr) = extractor.find_best(class.id);
        println!("  Id({}) size={} expr={}", class.id, class.data, expr);
    }
    println!();

    println!("Original expression: {} => {}", expr, extractor.find_best(egraph.lookup_expr(&expr).unwrap()).1);
}

type EGraph = egg::EGraph<SymbolLang, AstSizeAnalysis>;

struct AstSizeAnalysis;

fn build_check(max: usize, base: usize, vars: &str) -> impl Condition<SymbolLang, AstSizeAnalysis> {
    let vars = if vars.is_empty() {
        vec![]
    } else {
        vars.split(' ')
            .map(|v| Var::from_str(v).unwrap())
            .collect_vec()
    };

    struct ConditionImpl {
        base: usize,
        vars: Vec<Var>,
        max: usize,
    }

    impl Condition<SymbolLang, AstSizeAnalysis> for ConditionImpl {
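        fn check(&self, egraph: &mut egg::EGraph<SymbolLang, AstSizeAnalysis>, eclass: Id, subst: &Subst) -> bool {
            // Don't rewrite e-classes that already contain a term of size 1
            // (a constant or variable).
            let min_size = egraph[eclass].data;
            if min_size == 1 {
                return false;
            }

            // Estimate the size of the new term as a per-rule constant `base`
            // plus the smallest known size of each bound operand.
            let operand_sum = self.vars.iter()
                .map(|&v| egraph[*subst.get(v).unwrap()].data)
                .sum::<usize>();
            let new_size = self.base + operand_sum;

            new_size <= self.max
        }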

        fn vars(&self) -> Vec<Var> {
            self.vars.clone()
        }
    }

    ConditionImpl { max, base, vars }
}

impl Analysis<SymbolLang> for AstSizeAnalysis {
    type Data = usize;

    fn make(egraph: &mut EGraph, s: &SymbolLang) -> Self::Data {
        // One for this e-node plus the minimal size of each child e-class.
        let SymbolLang { op: _, children } = s;
        children.iter().copied().fold(1, |a, c| a + egraph[c].data)
    }

    fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> DidMerge {
        // The merged e-class can use whichever representative is smaller.
        let new = min(*a, b);
        let merge = DidMerge(new != *a, new != b);
        *a = new;
        merge
    }
}
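The two ingredients of this approach (a minimal-AST-size analysis and a size check on every rewrite) can also be illustrated without egg. The Rust sketch below assumes a hypothetical toy e-graph representation (`classes[c]` is the list of e-nodes of e-class `c`, and each e-node is given only by its child class ids) and recomputes all sizes in batch until nothing shrinks, whereas egg maintains the analysis incrementally through `make` and `merge`. The hypothetical `accept` function mirrors the check in `ConditionImpl`.

```rust
/// Toy e-graph: `classes[c]` holds the e-nodes of e-class `c`; an e-node is
/// represented only by the ids of its child e-classes (a leaf has none).
type EClass = Vec<Vec<usize>>;

/// Smallest AST size representable by each e-class: iterate
/// size(c) = min over e-nodes of (1 + sum of child sizes) until nothing shrinks.
/// Cycles such as `x == (+ x 0)` are fine, because sizes only ever decrease.
/// A class with no finite representation keeps `usize::MAX`.
fn min_ast_sizes(classes: &[EClass]) -> Vec<usize> {
    let mut size = vec![usize::MAX; classes.len()];
    loop {
        let mut changed = false;
        for (c, nodes) in classes.iter().enumerate() {
            for children in nodes {
                // saturating_add keeps "unknown" (usize::MAX) from overflowing
                let candidate = children
                    .iter()
                    .fold(1usize, |acc, &child| acc.saturating_add(size[child]));
                if candidate < size[c] {
                    size[c] = candidate;
                    changed = true;
                }
            }
        }
        if !changed {
            return size;
        }
    }
}

/// Mirrors the condition above: never rewrite an e-class that already has a
/// size-1 member; otherwise accept only if `base` plus the minimal sizes of
/// the bound operand classes stays within `max`.
fn accept(sizes: &[usize], eclass: usize, base: usize, operands: &[usize], max: usize) -> bool {
    if sizes[eclass] == 1 {
        return false;
    }
    let new_size = operands
        .iter()
        .fold(base, |acc, &c| acc.saturating_add(sizes[c]));
    new_size <= max
}
```

For example, an e-class holding both `x` and `(+ x 0)` gets size 1 despite the cycle, and is then never rewritten again; this is the same effect as the leaf pruning suggested earlier in the thread.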

@mwillsey
Member

mwillsey commented Sep 4, 2024

Good idea! I like the approach of limiting the AST size. Thanks for sharing!

@yihozhang
Contributor

yihozhang commented Sep 5, 2024

This is very cool! I had the same idea in my EGRAPHS talk (starting on page 29). I did some experiments but found it didn't work well for me, because this is still doubly exponential. But maybe I only tested extreme cases (e.g., I added x+0=x, so the e-graph is very cyclic). I'm excited to see how this works for you!

Edit: actually, I only tested limiting the depth; maybe limiting the depth is a dumb idea and limiting the AST size works way better.

@mwillsey
Member

mwillsey commented Sep 5, 2024

Both of you may also be interested in this unreleased feature to forbid returning certain cycles from e-matching: https://github.com/egraphs-good/egg/blob/main/src/language.rs#L781
