Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Function inlining #463

Merged
merged 22 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
0f8fba9
Function inlining in Rust
kirstenmg Apr 24, 2024
5757cab
Restore previous context adding
kirstenmg Apr 24, 2024
6aa8cc7
Replace context of children in rust subst
kirstenmg Apr 24, 2024
f525152
Deduplicate function inlining; rust code is too slow but egglog runs …
kirstenmg Apr 24, 2024
e663dd6
Try not replacing context; improves speed
kirstenmg Apr 24, 2024
790115d
Fix slow function inlining by actually caching in subst and not repla…
kirstenmg Apr 24, 2024
d5234a5
Clean up lib.rs
kirstenmg Apr 24, 2024
7da21b5
Add snapshots
kirstenmg Apr 24, 2024
be37526
Clean up code
kirstenmg Apr 24, 2024
7b0a23e
Update snapshots due to rebase
kirstenmg Apr 24, 2024
2dcecfb
Split get_calls_and_subst
kirstenmg Apr 25, 2024
2a0e4e6
Fix iteration bug in function inlining
kirstenmg Apr 26, 2024
a532714
Make egglog generation for func inlining fast by not hashing RcExprs
kirstenmg Apr 26, 2024
b0a8109
Make egglog faster for func inlining by sharing intermediate cache
kirstenmg Apr 26, 2024
eb7ca9d
Remove unneeded var_count
kirstenmg Apr 26, 2024
e3212e9
Share more caches
kirstenmg Apr 26, 2024
dea1395
Decrease function inlining to 2 iterations
kirstenmg Apr 26, 2024
161ba11
Factor out function inlining config
kirstenmg Apr 26, 2024
4f3245b
Clean up code; even fewer intermediates generated
kirstenmg Apr 26, 2024
83379ba
Refactor, remove silly method
kirstenmg Apr 26, 2024
4fb73f9
Merge with main
kirstenmg Apr 26, 2024
43b5e05
Refactor to work with context on leaves
kirstenmg Apr 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dag_in_context/src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub(crate) const FUNCTION_INLINING_ITERATIONS: usize = 2;
66 changes: 61 additions & 5 deletions dag_in_context/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ use greedy_dag_extractor::{extract, serialized_egraph, DefaultCostModel};
use interpreter::Value;
use schema::TreeProgram;
use std::fmt::Write;
use to_egglog::TreeToEgglog;

use crate::{interpreter::interpret_dag_prog, schedule::mk_schedule};
use crate::{
interpreter::interpret_dag_prog, optimizations::function_inlining, schedule::mk_schedule,
};

pub(crate) mod add_context;
pub mod ast;
mod config;
pub mod dag2svg;
pub mod dag_typechecker;
pub mod from_egglog;
Expand Down Expand Up @@ -47,7 +51,6 @@ pub fn prologue() -> String {
include_str!("utility/canonicalize.egg"),
include_str!("interval_analysis.egg"),
include_str!("optimizations/switch_rewrites.egg"),
include_str!("optimizations/function_inlining.egg"),
&optimizations::loop_invariant::rules().join("\n"),
include_str!("optimizations/loop_simplify.egg"),
include_str!("optimizations/loop_unroll.egg"),
Expand All @@ -59,6 +62,11 @@ pub fn prologue() -> String {
/// Adds an egglog program to `res` that adds the given term
/// to the database.
/// Returns a fresh variable referring to the program.
/// Note that because the cache caches based on a term, which
/// references the termdag, this cache **cannot** be reused
/// across different TermDags. Make sure to update the term dag
/// for a new term (using TreeToEgglog), rather than creating a
/// new term dag.
fn print_with_intermediate_helper(
termdag: &TermDag,
term: Term,
Expand All @@ -83,6 +91,7 @@ fn print_with_intermediate_helper(
let fresh_var = format!("__tmp{}", cache.len());
writeln!(res, "(let {fresh_var} ({head} {child_vars}))").unwrap();
cache.insert(term, fresh_var.clone());

fresh_var
}
}
Expand All @@ -96,12 +105,59 @@ pub(crate) fn print_with_intermediate_vars(termdag: &TermDag, term: Term) -> Str
printed
}

// Returns a formatted string of (union call body) for each pair
fn print_function_inlining_pairs(
function_inlining_pairs: Vec<function_inlining::CallBody>,
printed: &mut String,
tree_state: &mut TreeToEgglog,
term_cache: &mut HashMap<Term, String>,
) -> String {
// Get unions
let unions = function_inlining_pairs
.iter()
.map(|cb| {
let call_term = cb.call.to_egglog_internal(tree_state);
let body_term = cb.body.to_egglog_internal(tree_state);
format!(
"(union {} {})",
print_with_intermediate_helper(&tree_state.termdag, call_term, term_cache, printed),
print_with_intermediate_helper(&tree_state.termdag, body_term, term_cache, printed)
)
})
.collect::<Vec<_>>()
.join("\n");
unions
}

// It is expected that program has context added
pub fn build_program(program: &TreeProgram) -> String {
let (term, termdag) = program.to_egglog();
let printed = print_with_intermediate_vars(&termdag, term);
format!("{}\n{printed}\n{}\n", prologue(), mk_schedule(),)
let mut printed = String::new();

// Create a global cache for generating intermediate variables
let mut tree_state = TreeToEgglog::new();
let mut term_cache = HashMap::<Term, String>::new();

// Generate function inlining egglog
let function_inlining = print_function_inlining_pairs(
function_inlining::function_inlining_pairs(program, config::FUNCTION_INLINING_ITERATIONS),
&mut printed,
&mut tree_state,
&mut term_cache,
);

// Generate program egglog
let term = program.to_egglog_internal(&mut tree_state);
let res =
print_with_intermediate_helper(&tree_state.termdag, term, &mut term_cache, &mut printed);

format!(
"{}\n{printed}\n(let PROG {res})\n\n{function_inlining}\n{}\n",
prologue(),
mk_schedule()
)
}

// It is expected that program has context added
pub fn optimize(program: &TreeProgram) -> std::result::Result<TreeProgram, egglog::Error> {
let egglog_prog = build_program(program);
let mut egraph = egglog::EGraph::default();
Expand Down
90 changes: 90 additions & 0 deletions dag_in_context/src/optimizations/function_inlining.rs
Original file line number Diff line number Diff line change
@@ -1 +1,91 @@
use std::{
collections::{HashMap, HashSet},
rc::Rc,
vec,
};

use crate::schema::{Expr, RcExpr, TreeProgram};

#[derive(Clone, PartialEq, PartialOrd, Eq, Ord)]
pub struct CallBody {
pub call: RcExpr,
pub body: RcExpr,
}

// Gets a set of all the calls in the program
fn get_calls(expr: &RcExpr) -> Vec<RcExpr> {
// Get calls from children
let mut calls = if !expr.children_exprs().is_empty() {
expr.children_exprs()
.iter()
.flat_map(get_calls)
.collect::<Vec<_>>()
} else {
Vec::new()
};

// Add to set if this is a call
if let Expr::Call(_, _) = expr.as_ref() {
calls.push(expr.clone());
}

calls
}

// Pairs a call with its equivalent inlined body, using the passed-in function -> body map
// to look up the body
fn subst_call(call: &RcExpr, func_to_body: &HashMap<String, &RcExpr>) -> CallBody {
if let Expr::Call(func_name, args) = call.as_ref() {
CallBody {
call: call.clone(),
body: Expr::subst(args, func_to_body[func_name]),
}
} else {
panic!("Tried to substitute non-calls.")
}
}

// Generates a list of (call, body) pairs (in a CallBody) that can be unioned
pub fn function_inlining_pairs(program: &TreeProgram, iterations: usize) -> Vec<CallBody> {
let mut all_funcs = vec![program.entry.clone()];
all_funcs.extend(program.functions.clone());

// Make func name -> body map
let func_name_to_body = all_funcs
.iter()
.map(|func| {
(
func.func_name().expect("Func has name"),
func.func_body().expect("Func has body"),
)
})
.collect::<HashMap<String, &RcExpr>>();

// Inline once
// Keep track of all calls we've seen so far to avoid duplication
let mut prev_calls: HashSet<*const Expr> = HashSet::new();
let mut prev_inlining = all_funcs
.iter()
.flat_map(get_calls)
// Deduplicate calls before substitution
.filter(|call| prev_calls.insert(Rc::as_ptr(call)))
// We cannot hash RcExprs because it is too slow
.map(|call| subst_call(&call, &func_name_to_body))
.collect::<Vec<_>>();

let mut all_inlining = prev_inlining.clone();

// Repeat! Get calls and subst for each new substituted body.
for _ in 1..iterations {
let next_inlining = prev_inlining
.iter()
.flat_map(|cb| get_calls(&cb.body))
.filter(|call| prev_calls.insert(Rc::as_ptr(call)))
.map(|call| subst_call(&call, &func_name_to_body))
.collect::<Vec<_>>();
all_inlining.extend(next_inlining.clone());
prev_inlining = next_inlining;
}

all_inlining
}
2 changes: 2 additions & 0 deletions dag_in_context/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ pub enum Constant {
/// We want sharing between sub-expressions, so we use Rc instead of Box.
/// Invariant: Every shared sub-expression is re-used by the same Rc<Expr> (pointer equality).
/// This is important for the correctness of the interpreter, which makes this assumption.
/// NOTE: Please do not hash this. Hash a *const Expr instead. The hash function for RcExpr
/// is very slow due to sharing of subexpressions.
pub type RcExpr = Rc<Expr>;

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
Expand Down
Loading
Loading