Skip to content

Commit

Permalink
fixed bug in previous commit and added more cases for ___s removal
Browse files Browse the repository at this point in the history
  • Loading branch information
HoseongLee committed Aug 22, 2024
1 parent ff39e65 commit 36284b3
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 122 deletions.
148 changes: 61 additions & 87 deletions src/ai/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use rustc_middle::{
ty::{AdtKind, Ty, TyCtxt, TyKind, TypeAndMut},
};
use rustc_session::config::Input;
use rustc_span::{def_id::DefId, source_map::SourceMap, Span};
use rustc_span::{def_id::DefId, source_map::SourceMap, BytePos, Span};
use serde::{Deserialize, Serialize};

use super::{domains::*, semantics::TransferedTerminator};
Expand Down Expand Up @@ -50,7 +50,7 @@ impl Default for AnalysisConfig {
pub type OutputParams = BTreeMap<String, Vec<OutputParam>>;

// Write Before RETurn s
pub type Wbrets = BTreeMap<i128, BTreeSet<usize>>;
pub type Wbrets = BTreeMap<i128, (BTreeSet<usize>, BTreeSet<usize>)>;

// Removable checks for Write s
pub type Rcfws = BTreeMap<usize, BTreeSet<i128>>;
Expand Down Expand Up @@ -161,7 +161,8 @@ pub fn analyze(
let mut wm_map = BTreeMap::new();
let mut call_args_map = BTreeMap::new();
let mut analysis_times: BTreeMap<_, u128> = BTreeMap::new();
let mut wbrets: BTreeMap<DefId, BTreeMap<i128, BTreeSet<usize>>> = BTreeMap::new();

let mut wbrets: BTreeMap<DefId, Wbrets> = BTreeMap::new();
let mut wbbbrets: BTreeMap<DefId, BTreeMap<BasicBlock, BTreeSet<usize>>> = BTreeMap::new();

let mut rcfws = BTreeMap::new();
Expand Down Expand Up @@ -220,100 +221,76 @@ pub fn analyze(
let mut wbbbret = BTreeMap::new();

if let Some(ret_location) = ret_location {
let mut stack = vec![];

if let Some((ret_loc_assign0, index)) = exists_assign0(body, ret_location.block)
{
let loc = Location {
block: ret_location.block,
statement_index: index,
};

let writes: BTreeSet<_> = states
.get(&loc)
.cloned()
.unwrap_or_default()
.values()
.flat_map(|st| st.writes.as_set())
.map(|p| p.base() - 1)
.collect();

wbret.insert(
unsafe { std::mem::transmute(ret_loc_assign0.data()) },
writes.clone(),
);
wbbbret.insert(ret_location.block, writes);
stack.push((
Location {
block: ret_location.block,
statement_index: index,
},
ret_loc_assign0.data(),
));
} else {
let preds = body.basic_blocks.predecessors().get(ret_location.block);

if let Some(v) = preds {
for i in v {
if let Some((sp, index)) = exists_assign0(body, *i) {
let loc = Location {
block: *i,
statement_index: index,
};

let writes: BTreeSet<_> = states
.get(&loc)
.cloned()
.unwrap_or_default()
.values()
.flat_map(|st| st.writes.as_set())
.map(|p| p.base() - 1)
.collect();

wbret.insert(
unsafe { std::mem::transmute(sp.data()) },
writes.clone(),
);
wbbbret.insert(*i, writes);
stack.push((
Location {
block: *i,
statement_index: index,
},
sp.data(),
));
}
}
}
}
}

let wbret = if let Some(old) = wbrets.get(def_id) {
let spans: BTreeSet<_> = wbret.keys().chain(old.keys()).cloned().collect();
if stack.is_empty() {
let span = body.source_info(ret_location).span;
let pos = span.lo() - BytePos(1);
stack.push((ret_location, span.with_lo(pos).with_hi(pos).data()));
}

spans
.into_iter()
.map(|sp| {
(
sp,
match (wbret.get(&sp), old.get(&sp)) {
(Some(v1), Some(v2)) => {
v1.intersection(v2).cloned().collect::<BTreeSet<usize>>()
for (loc, sp) in stack.iter() {
let must_writes: BTreeSet<_> = states
.get(loc)
.cloned()
.unwrap_or_default()
.values()
.fold(None, |acc: Option<BTreeSet<_>>, st: &AbsState| {
Some(match acc {
Some(acc) => {
acc.intersection(st.writes.as_set()).cloned().collect()
}
(Some(v), None) | (None, Some(v)) => (*v).clone(),
_ => unreachable!(),
},
)
})
.collect::<BTreeMap<i128, _>>()
} else {
wbret
};
None => st.writes.as_set().clone(),
})
})
.unwrap_or_default()
.iter()
.map(|p| p.base() - 1)
.collect();

let wbbbret = if let Some(old) = wbbbrets.get(def_id) {
let keys: BTreeSet<_> = wbbbret.keys().chain(old.keys()).cloned().collect();
let may_writes: BTreeSet<_> = states
.get(loc)
.cloned()
.unwrap_or_default()
.values()
.flat_map(|st| st.writes.as_set())
.map(|p| p.base() - 1)
.collect();

keys.into_iter()
.map(|bb| {
(
bb,
match (wbbbret.get(&bb), old.get(&bb)) {
(Some(v1), Some(v2)) => {
v1.intersection(v2).cloned().collect::<BTreeSet<usize>>()
}
(Some(v), None) | (None, Some(v)) => (*v).clone(),
_ => unreachable!(),
},
)
})
.collect::<BTreeMap<BasicBlock, _>>()
} else {
wbbbret
};
wbret.insert(
unsafe { std::mem::transmute(*sp) },
(may_writes, must_writes.clone()),
);

wbbbret.insert(loc.block, must_writes);
}
}

wbrets.insert(*def_id, wbret);
wbbbrets.insert(*def_id, wbbbret);
Expand Down Expand Up @@ -381,13 +358,10 @@ pub fn analyze(

let success = loop {
if let Some(block) = stack.pop() {
match wbbbret.get(&block) {
Some(ws) => {
if !ws.contains(index) {
break false;
}
if let Some(ws) = wbbbret.get(&block) {
if !ws.contains(index) {
break false;
}
None => (),
}

let bbd = &body.basic_blocks[block];
Expand Down
89 changes: 54 additions & 35 deletions src/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,17 @@ fn transform(
wbret
.clone()
.into_iter()
.map(|(sp, params)| {
.map(|(sp, (may, must))| {
(
unsafe { std::mem::transmute::<i128, SpanData>(sp) }.span(),
params
.iter()
.filter_map(|i| index_map.get(i).map(|p| p.name.clone()))
.collect::<Vec<_>>(),
(
may.iter()
.filter_map(|i| index_map.get(i).map(|p| p.name.clone()))
.collect::<Vec<_>>(),
must.iter()
.filter_map(|i| index_map.get(i).map(|p| p.name.clone()))
.collect::<Vec<_>>(),
),
)
})
.collect::<Vec<_>>(),
Expand Down Expand Up @@ -386,11 +390,11 @@ fn transform(

let post_span = post_span.with_hi(post_span.hi() + BytePos(1));
let rv = format!("{}rv___{}", pre_s, post_s);
let (_, wbret) = wbrets[&def_id]
let (_, (may, must)) = wbrets[&def_id]
.iter()
.find(|(sp, _)| span.contains(*sp))
.unwrap();
let rv = func.return_value(Some(rv), wbret);
let rv = func.return_value(Some(rv), may, must);
fix(
post_span,
format!("; return {};{}", rv, if arm { " }" } else { "" }),
Expand Down Expand Up @@ -439,10 +443,33 @@ fn transform(
let ret_ty = func.return_type(orig);
fix(span, format!("-> {}", ret_ty));

let mut unremovable = BTreeSet::new();
for param in func.params() {
let rcfw = &rcfws[&def_id].get(&param.name);

for span in &param.writes {
if call_spans.contains(span) {
continue;
}
if let Some(rcfw) = rcfw {
if rcfw.iter().any(|sp| span.contains(*sp)) {
continue;
}
}

unremovable.insert(&param.name);

let pos = span.hi() + BytePos(1);
let span = span.with_hi(pos).with_lo(pos);
let assign = format!("{0}___s = true;", param.name);
fix(span, assign);
}
}

let local_vars: String = func
.params()
.map(|param| {
if param.must {
if param.must || !unremovable.contains(&param.name) {
if passes.contains(&param.name) {
format!(
"
Expand Down Expand Up @@ -480,26 +507,6 @@ fn transform(
let span = body.value.span.with_lo(pos).with_hi(pos);
fix(span, local_vars);

for param in func.params() {
let rcfw = &rcfws[&def_id].get(&param.name);

for span in &param.writes {
if call_spans.contains(span) {
continue;
}
if let Some(rcfw) = rcfw {
if rcfw.iter().any(|sp| span.contains(*sp)) {
continue;
}
}

let pos = span.hi() + BytePos(1);
let span = span.with_hi(pos).with_lo(pos);
let assign = format!("{0}___s = true;", param.name);
fix(span, assign);
}
}

for param in func.params() {
if let Some(spans) = ref_to_spans.get(&param.name) {
for span in spans {
Expand Down Expand Up @@ -534,7 +541,11 @@ fn transform(
let post_s = source_map.span_to_snippet(*s).unwrap();
rv = format!("{}{}___v{}", rv, sorted_ss[i + 1].0, post_s);
}
let rv = func.return_value(Some(rv), &[]);
let (_, (may, must)) = wbrets[&def_id]
.iter()
.find(|(sp, _)| span.contains(*sp))
.unwrap();
let rv = func.return_value(Some(rv), may, must);

fix(*span, format!("return {}", rv));
}
Expand All @@ -545,18 +556,22 @@ fn transform(
continue;
}
let orig = value.map(|value| source_map.span_to_snippet(value).unwrap());
let (_, wbret) = wbrets[&def_id]
let (_, (may, must)) = wbrets[&def_id]
.iter()
.find(|(sp, _)| span.contains(*sp))
.unwrap();
let ret_v = func.return_value(orig, wbret);
let ret_v = func.return_value(orig, may, must);
fix(span, format!("return {}", ret_v));
}

if func.is_unit {
let pos = body.value.span.hi() - BytePos(1);
let span = body.value.span.with_lo(pos).with_hi(pos);
let ret_v = func.return_value(None, &[]);
let (_, (may, must)) = wbrets[&def_id]
.iter()
.find(|(sp, _)| span.contains(*sp))
.unwrap();
let ret_v = func.return_value(None, may, must);
fix(span, ret_v);
}
}
Expand Down Expand Up @@ -720,13 +735,15 @@ impl Func {
}
}

fn return_value(&self, orig: Option<String>, wbret: &[String]) -> String {
fn return_value(&self, orig: Option<String>, may: &[String], must: &[String]) -> String {
let mut values = vec![];
if let Some((_, i)) = &self.first_return {
let orig = orig.unwrap();
let param = &self.index_map[i];
let v = if wbret.contains(&param.name) {
let v = if must.contains(&param.name) {
format!("Ok({}___v)", param.name)
} else if !may.contains(&param.name) {
format!("Err({})", orig)
} else {
format!(
"if {0}___s {{ Ok({0}___v) }} else {{ Err({1}) }}",
Expand All @@ -741,8 +758,10 @@ impl Func {
let param = &self.index_map[i];
let v = if param.must {
format!("{}___v", param.name)
} else if wbret.contains(&param.name) {
} else if must.contains(&param.name) {
format!("Some ({}___v)", param.name)
} else if !may.contains(&param.name) {
"None".to_string()
} else {
format!("if {0}___s {{ Some({0}___v) }} else {{ None }}", param.name)
};
Expand Down

0 comments on commit 36284b3

Please sign in to comment.