Skip to content

Commit

Permalink
count other simplifications
Browse files Browse the repository at this point in the history
  • Loading branch information
Medowhill committed Nov 5, 2024
1 parent 3c16ea6 commit bda75ea
Showing 1 changed file with 57 additions and 67 deletions.
124 changes: 57 additions & 67 deletions src/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@ use rustfix::Suggestion;

use crate::{ai::analysis::*, compile_util};

static mut N_MUST: usize = 0;
static mut N_MAY: usize = 0;
static mut N_DIRECT_RETURNS: usize = 0;
static mut N_REMOVED_CHECKS: usize = 0;

#[derive(Default, Clone, Copy)]
struct Counter {
simplify: bool,
removed_pointer_defs: usize,
removed_pointer_uses: usize,
direct_returns: usize,
success_returns: usize,
failure_returns: usize,
removed_flag_sets: usize,
removed_flag_defs: usize,
}

pub fn transform_path(path: &Path, analysis_result: &AnalysisResult, simplify: bool) {
Expand All @@ -44,7 +45,10 @@ fn transform(
analysis_result: &AnalysisResult,
simplify: bool,
) -> BTreeMap<PathBuf, Vec<Suggestion>> {
let mut counter = Counter::default();
let mut counter = Counter {
simplify,
..Counter::default()
};

let hir = tcx.hir();
let source_map = tcx.sess.source_map();
Expand Down Expand Up @@ -321,7 +325,7 @@ fn transform(
let assign_map = curr.map(|c| c.assign_map(span)).unwrap_or_default();

let mut mtch = func.first_return.and_then(|(_, first)| {
let set_flag = generate_set_flag(&span, &first, rcfw, &assign_map);
let set_flag = generate_set_flag(&span, &first, rcfw, &assign_map, &mut counter);
func.call_match(&args, set_flag)
});

Expand All @@ -335,7 +339,7 @@ fn transform(
let fail = "Err(_) => ";
let (_, i) = func.first_return.as_ref().unwrap();
let arg = &args[*i];
let set_flag = generate_set_flag(&span, i, rcfw, &assign_map);
let set_flag = generate_set_flag(&span, i, rcfw, &assign_map, &mut counter);
let assign = if arg.code.contains("&mut ") {
format!(" *({}) = v___; {}", arg.code, set_flag)
} else {
Expand Down Expand Up @@ -408,6 +412,7 @@ fn transform(
.find(|(sp, _)| span.contains(*sp))
.map(|r| &r.1),
None,
&mut counter,
);
fix(
post_span,
Expand All @@ -425,7 +430,12 @@ fn transform(
let set_flags = func
.remaining_return
.iter()
.map(|i| (*i, generate_set_flag(&span, i, rcfw, &assign_map)))
.map(|i| {
(
*i,
generate_set_flag(&span, i, rcfw, &assign_map, &mut counter),
)
})
.collect();

let mut assign = func.call_assign(&args, &set_flags);
Expand Down Expand Up @@ -458,7 +468,8 @@ fn transform(

for span in &param.writes {
if let Some(rcfw) = rcfw {
if rcfw.iter().any(|sp| span.contains(*sp)) {
if rcfw.iter().any(|sp| span.contains(*sp)) && simplify {
counter.removed_flag_sets += 1;
continue;
}
}
Expand All @@ -469,10 +480,6 @@ fn transform(
continue;
}

unsafe {
N_REMOVED_CHECKS += 1;
}

let pos = span.hi() + BytePos(1);
let span = span.with_hi(pos).with_lo(pos);
let assign = format!("{0}___s = true;", param.name);
Expand All @@ -483,7 +490,10 @@ fn transform(
let local_vars: String = func
.params()
.map(|param| {
if param.must || (!unremovable.contains(&param.name)) {
if param.must || (!unremovable.contains(&param.name) && simplify) {
if !param.must {
counter.removed_flag_defs += 1;
}
if passes.contains(&param.name) || !simplify {
format!(
"
Expand Down Expand Up @@ -533,22 +543,18 @@ fn transform(

let mut assign_before_ret = None;
for assign in visitor.assigns.iter() {
let Assign {
name: _,
value: _,
span: sp,
} = assign;
if visitor.calls.iter().any(|call| call.span.overlaps(*sp)) {
if visitor
.calls
.iter()
.any(|call| call.span.overlaps(assign.span))
{
continue;
}
let sp2 = sp.between(*span);
if source_map
.span_to_snippet(sp2)
.span_to_snippet(assign.span.between(*span))
.unwrap()
.chars()
.filter(|c| !c.is_whitespace())
.count()
== 0
.all(|c| c.is_whitespace())
{
assign_before_ret = Some(assign);
break;
Expand Down Expand Up @@ -578,6 +584,7 @@ fn transform(
.find(|(sp, _)| span.contains(*sp))
.map(|r| &r.1),
lit_map,
&mut counter,
);
fix(*span, format!("return {}", ret_v));
}
Expand Down Expand Up @@ -627,6 +634,7 @@ fn transform(
.find(|(sp, _)| span.contains(*sp))
.map(|r| &r.1),
None,
&mut counter,
);

fix(span, format!("return {}", rv));
Expand All @@ -649,7 +657,7 @@ fn transform(
}

if !skip {
let ret_v = func.return_value(None, None, None);
let ret_v = func.return_value(None, None, None, &mut counter);
fix(span, format!("\t{}\n", ret_v));
}
}
Expand All @@ -667,17 +675,11 @@ fn transform(

println!("Removed pointer defs: {}", counter.removed_pointer_defs);
println!("Removed pointer uses: {}", counter.removed_pointer_uses);
println!("Must write before return simplifications : {}", unsafe {
N_MUST
});
println!(
"Must not write before return simplifications : {}",
unsafe { N_MAY }
);
println!("Direct return simplifications : {}", unsafe {
N_DIRECT_RETURNS
});
println!("Removed checks: {}", unsafe { N_REMOVED_CHECKS });
println!("Direct returns: {}", counter.direct_returns);
println!("Success returns: {}", counter.success_returns);
println!("Failure returns: {}", counter.failure_returns);
println!("Removed flag sets: {}", counter.removed_flag_sets);
println!("Removed flag defs: {}", counter.removed_flag_defs);

suggestions
}
Expand Down Expand Up @@ -834,33 +836,28 @@ impl Func {
orig: Option<String>,
wbr: Option<&(Vec<String>, Vec<String>)>,
lit_map: Option<(&String, &String)>,
counter: &mut Counter,
) -> String {
let mut values = vec![];
if let Some((_, i)) = &self.first_return {
let orig = orig.unwrap();
let param = &self.index_map[i];
let name = lit_map
.and_then(|(n, v)| {
if *n == param.name {
unsafe {
N_DIRECT_RETURNS += 1;
}
if *n == param.name && counter.simplify {
counter.direct_returns += 1;
Some((*v).clone())
} else {
None
}
})
.unwrap_or(format!("{}___v", param.name));
let v = if let Some((may, must)) = wbr {
if must.contains(&param.name) {
unsafe {
N_MUST += 1;
}
if must.contains(&param.name) && counter.simplify {
counter.success_returns += 1;
format!("Ok({})", name)
} else if !may.contains(&param.name) {
unsafe {
N_MAY += 1;
}
} else if !may.contains(&param.name) && counter.simplify {
counter.failure_returns += 1;
format!("Err({})", orig)
} else {
format!(
Expand All @@ -882,10 +879,8 @@ impl Func {
let param = &self.index_map[i];
let name = lit_map
.and_then(|(n, v)| {
if *n == param.name {
unsafe {
N_DIRECT_RETURNS += 1;
}
if *n == param.name && counter.simplify {
counter.direct_returns += 1;
Some((*v).clone())
} else {
None
Expand All @@ -895,15 +890,11 @@ impl Func {
let v = if param.must {
name
} else if let Some((may, must)) = wbr {
if must.contains(&param.name) {
unsafe {
N_MUST += 1;
}
format!("Some ({})", name)
} else if !may.contains(&param.name) {
unsafe {
N_MAY += 1;
}
if must.contains(&param.name) && counter.simplify {
counter.success_returns += 1;
format!("Some({})", name)
} else if !may.contains(&param.name) && counter.simplify {
counter.failure_returns += 1;
"None".to_string()
} else {
format!(
Expand Down Expand Up @@ -1270,14 +1261,13 @@ fn generate_set_flag(
i: &usize,
rcfws: &BTreeMap<String, Vec<Span>>,
assign_map: &BTreeMap<usize, String>,
counter: &mut Counter,
) -> String {
if let Some(arg) = assign_map.get(i) {
let rcfw = &rcfws.get(arg);
if let Some(rcfw) = rcfw {
if rcfw.iter().any(|sp| span.contains(*sp)) {
unsafe {
N_REMOVED_CHECKS += 1;
}
if rcfw.iter().any(|sp| span.contains(*sp)) && counter.simplify {
counter.removed_flag_sets += 1;
return "".to_string();
}
}
Expand Down

0 comments on commit bda75ea

Please sign in to comment.