Skip to content

Commit

Permalink
Merge pull request #2 from kaist-plrg/simpl
Browse files Browse the repository at this point in the history
Simpl
  • Loading branch information
Medowhill authored Nov 5, 2024
2 parents 4f52a21 + bda75ea commit 97f4dc2
Show file tree
Hide file tree
Showing 4 changed files with 671 additions and 165 deletions.
206 changes: 197 additions & 9 deletions src/ai/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::{
path::Path,
};

use compile_util::LoHi;
use etrace::some_or;
use rustc_abi::VariantIdx;
use rustc_hir::{
Expand All @@ -14,7 +15,7 @@ use rustc_hir::{
use rustc_index::bit_set::BitSet;
use rustc_middle::{
hir::nested_filter,
mir::{BasicBlock, Body, Local, Location, TerminatorKind},
mir::{BasicBlock, Body, Local, Location, StatementKind, TerminatorKind},
ty::{AdtKind, Ty, TyCtxt, TyKind, TypeAndMut},
};
use rustc_session::config::Input;
Expand Down Expand Up @@ -47,7 +48,24 @@ impl Default for AnalysisConfig {
}
}

pub type AnalysisResult = BTreeMap<String, Vec<OutputParam>>;
pub type AnalysisResult = BTreeMap<String, FnAnalysisRes>;

/// Everything the analysis reports for a single function.
///
/// This is the value type of [`AnalysisResult`], which is keyed by the
/// function's definition path string.
#[derive(Debug, Serialize, Deserialize)]
pub struct FnAnalysisRes {
/// Output parameters identified for the function.
pub output_params: Vec<OutputParam>,
/// Write information recorded at the places where the function's return
/// value is assigned (see [`WriteBeforeReturn`]).
pub wbrs: Vec<WriteBeforeReturn>,
/// Removable checks for writes, grouped by output-parameter index
/// (see [`Rcfws`]).
pub rcfws: Rcfws,
}

/// Write information captured at one assignment to the return place (local 0).
///
/// `analyze` builds one of these for every location reported by
/// `exists_assign0`, using the abstract states recorded at that location.
#[derive(Debug, Serialize, Deserialize)]
pub struct WriteBeforeReturn {
/// Source range of the statement or call that assigns the return place.
pub span: LoHi,
/// Indices (`base() - 1` of each written path) that appear in the write set
/// of at least one state at the location, i.e. the union over all states.
pub mays: BTreeSet<usize>,
/// Indices (`base() - 1` of each written path) that appear in the write set
/// of every state at the location, i.e. the intersection over all states.
/// `analyze` compares these against `OutputParam::index`.
pub musts: BTreeSet<usize>,
}

// Removable checks for writes ("Rcfws").
//
// Maps an output parameter's index to the source ranges of its complete
// writes that `analyze` recorded as removable: a complete write is recorded
// when no block reachable from it has a must-write set that lacks the
// parameter's index.
pub type Rcfws = BTreeMap<usize, BTreeSet<LoHi>>;

pub fn analyze_path(path: &Path, conf: &AnalysisConfig) -> AnalysisResult {
analyze_input(compile_util::path_to_input(path), conf)
Expand All @@ -62,16 +80,18 @@ pub fn analyze_input(input: Input, conf: &AnalysisConfig) -> AnalysisResult {
compile_util::run_compiler(config, |tcx| {
analyze(tcx, conf)
.into_iter()
.filter_map(|(def_id, (_, params))| {
if params.is_empty() {
.filter_map(|(def_id, (_, res))| {
if res.output_params.is_empty() {
None
} else {
Some((tcx.def_path_str(def_id), params))
Some((tcx.def_path_str(def_id), res))
}
})
.collect()
.collect::<Vec<_>>()
})
.unwrap()
.into_iter()
.collect()
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
Expand All @@ -84,7 +104,7 @@ enum Write {
pub fn analyze(
tcx: TyCtxt<'_>,
conf: &AnalysisConfig,
) -> BTreeMap<DefId, (FunctionSummary, Vec<OutputParam>)> {
) -> BTreeMap<DefId, (FunctionSummary, FnAnalysisRes)> {
let hir = tcx.hir();

let mut call_graph = BTreeMap::new();
Expand Down Expand Up @@ -150,6 +170,13 @@ pub fn analyze(
let mut wm_map = BTreeMap::new();
let mut call_args_map = BTreeMap::new();
let mut analysis_times: BTreeMap<_, u128> = BTreeMap::new();

let mut wbrs: BTreeMap<DefId, Vec<WriteBeforeReturn>> = BTreeMap::new();
let mut bb_musts: BTreeMap<DefId, BTreeMap<BasicBlock, BTreeSet<usize>>> = BTreeMap::new();
let mut is_units = BTreeMap::new();

let mut rcfws = BTreeMap::new();

for id in &po {
let def_ids = &elems[id];
let recursive = if def_ids.len() == 1 {
Expand Down Expand Up @@ -199,7 +226,61 @@ pub fn analyze(
.flat_map(|p| analyzer.expands_path(&AbsPath(vec![*p])))
.collect();

let mut return_states = return_location(body)
let ret_location = return_location(body);

let mut wbr = vec![];
let mut bb_must = BTreeMap::new();

let mut stack = vec![];

if let Some(ret_location) = ret_location {
if let Some((ret_loc_assign0, ret_loc)) =
exists_assign0(body, ret_location.block)
{
stack.push((ret_loc, ret_loc_assign0));
} else if let Some(v) = body.basic_blocks.predecessors().get(ret_location.block)
{
for i in v {
if let Some((sp, ret_loc)) = exists_assign0(body, *i) {
stack.push((ret_loc, sp));
}
}
}

let empty_map = BTreeMap::new();
for (loc, sp) in stack.iter() {
let writes: Vec<_> = states
.get(loc)
.unwrap_or(&empty_map)
.values()
.map(|st| st.writes.as_set())
.collect();
let musts: BTreeSet<_> = writes
.iter()
.copied()
.fold(None, |acc: Option<BTreeSet<_>>, ws| {
Some(match acc {
Some(acc) => acc.intersection(ws).cloned().collect(),
None => ws.clone(),
})
})
.unwrap_or_default()
.iter()
.map(|p| p.base() - 1)
.collect();
let mays: BTreeSet<_> =
writes.into_iter().flatten().map(|p| p.base() - 1).collect();
let span = LoHi::from_span(*sp);
bb_must.insert(loc.block, musts.clone());
wbr.push(WriteBeforeReturn { span, mays, musts });
}
}

wbrs.insert(*def_id, wbr);
bb_musts.insert(*def_id, bb_must);
is_units.insert(*def_id, stack.is_empty());

let mut return_states = ret_location
.and_then(|ret| states.get(&ret))
.cloned()
.unwrap_or_default();
Expand Down Expand Up @@ -240,12 +321,74 @@ pub fn analyze(
for p in &mut output_params {
analyzer.find_complete_write(p, &result, &writes_map, &call_args, *def_id);
}

let body = tcx.optimized_mir(*def_id);
let bb_must = &bb_musts[def_id];
let mut rcfw: Rcfws = BTreeMap::new();

if !is_units[def_id] {
for p in &output_params {
let OutputParam {
index,
complete_writes,
..
} = p;
for complete_write in complete_writes {
let CompleteWrite {
block,
statement_index,
..
} = complete_write;

let mut stack = vec![BasicBlock::from_usize(*block)];
let mut visited: BTreeSet<_> = stack.iter().cloned().collect();

let always_write = loop {
if let Some(bb) = stack.pop() {
if let Some(musts) = bb_must.get(&bb) {
if !musts.contains(index) {
break false;
}
}

let term = body.basic_blocks[bb].terminator();
for bb in term.successors() {
if !visited.contains(&bb) {
visited.insert(bb);
stack.push(bb);
}
}
} else {
break true;
}
};

if always_write {
let location = Location {
block: BasicBlock::from_usize(*block),
statement_index: *statement_index,
};
let span = LoHi::from_span(body.source_info(location).span);
let entry = rcfw.entry(*index);
entry.or_default().insert(span);
}
}
}
}

rcfws.insert(*def_id, rcfw);
output_params_map.insert(*def_id, output_params);
}
break;
}
}
}

if conf.max_loop_head_states <= 1 {
wbrs.clear();
rcfws.clear();
}

if let Some(n) = &conf.function_times {
let mut analysis_times: Vec<_> = analysis_times.into_iter().collect();
analysis_times.sort_by_key(|(_, t)| u128::MAX - *t);
Expand All @@ -262,7 +405,14 @@ pub fn analyze(
.into_iter()
.map(|(def_id, summary)| {
let output_params = output_params_map.remove(&def_id).unwrap();
(def_id, (summary, output_params))
let wbrs = wbrs.remove(&def_id).unwrap_or_default();
let rcfws = rcfws.remove(&def_id).unwrap_or_default();
let res = FnAnalysisRes {
output_params,
wbrs,
rcfws,
};
(def_id, (summary, res))
})
.collect()
}
Expand Down Expand Up @@ -1062,6 +1212,44 @@ fn return_location(body: &Body<'_>) -> Option<Location> {
None
}

/// Finds where local 0 (the MIR return place `_0`) is assigned in block `bb`.
///
/// The block's statements are scanned in order. The first `Assign` statement
/// whose destination local has index 0 yields that statement's span together
/// with the statement's own location.
///
/// If no statement matches and the block's terminator is a `Call` whose
/// destination local has index 0, the terminator's span is returned together
/// with the location of the first statement of the call's target block,
/// i.e. the first point at which the call's result is available.
///
/// Returns `None` when neither case applies, and also when the matching call
/// has no target block (instead of panicking on the missing target).
fn exists_assign0(body: &Body<'_>, bb: BasicBlock) -> Option<(Span, Location)> {
    let data = &body.basic_blocks[bb];

    // First look for a plain assignment statement `_0 = ...` inside the block.
    let from_statement = data
        .statements
        .iter()
        .enumerate()
        .find_map(|(statement_index, stmt)| {
            if let StatementKind::Assign(rb) = &stmt.kind {
                if (**rb).0.local.as_u32() == 0u32 {
                    return Some((
                        stmt.source_info.span,
                        Location {
                            block: bb,
                            statement_index,
                        },
                    ));
                }
            }
            None
        });
    if from_statement.is_some() {
        return from_statement;
    }

    // Otherwise the return place may be written by the call that ends the block.
    let term = data.terminator();
    if let TerminatorKind::Call {
        destination,
        target,
        ..
    } = term.kind
    {
        if destination.local.as_u32() == 0u32 {
            // A call without a target never returns, so there is no location
            // after it; report "not found" rather than unwrapping.
            let block = target?;
            return Some((
                term.source_info.span,
                Location {
                    block,
                    statement_index: 0,
                },
            ));
        }
    }
    None
}

fn get_rpo_map(body: &Body<'_>) -> BTreeMap<BasicBlock, usize> {
body.basic_blocks
.reverse_postorder()
Expand Down
14 changes: 8 additions & 6 deletions src/bin/nopcrat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ struct Args {

#[arg(short, long)]
transform: bool,
#[arg(long)]
simplify: bool,
#[arg(short, long)]
size: bool,
#[arg(long)]
Expand Down Expand Up @@ -98,9 +100,9 @@ fn main() {
};

if args.verbose {
for (func, params) in &analysis_result {
for (func, res) in &analysis_result {
println!("{}", func);
for param in params {
for param in &res.output_params {
println!(" {:?}", param);
}
}
Expand All @@ -117,7 +119,7 @@ fn main() {
if args.sample_may || args.sample_must {
let mut params: Vec<_> = analysis_result
.iter()
.filter(|(_, params)| params.iter().any(|p| p.must == args.sample_must))
.filter(|(_, res)| res.output_params.iter().any(|p| p.must == args.sample_must))
.collect();
params.shuffle(&mut thread_rng());
for (f, ps) in params.iter().take(10) {
Expand All @@ -130,11 +132,11 @@ fn main() {
let fns = analysis_result.len();
let musts = analysis_result
.values()
.map(|v| v.iter().filter(|p| p.must).count())
.map(|res| res.output_params.iter().filter(|p| p.must).count())
.sum::<usize>();
let mays = analysis_result
.values()
.map(|v| v.iter().filter(|p| !p.must).count())
.map(|res| res.output_params.iter().filter(|p| !p.must).count())
.sum::<usize>();
println!("{} {} {}", fns, musts, mays);
}
Expand All @@ -148,7 +150,7 @@ fn main() {
return;
}

transform::transform_path(path, &analysis_result);
transform::transform_path(path, &analysis_result, args.simplify);
}

fn clear_dir(path: &Path) {
Expand Down
Loading

0 comments on commit 97f4dc2

Please sign in to comment.