diff --git a/src/ai/analysis.rs b/src/ai/analysis.rs index 24ac045..c4b5916 100644 --- a/src/ai/analysis.rs +++ b/src/ai/analysis.rs @@ -4,6 +4,7 @@ use std::{ path::Path, }; +use compile_util::LoHi; use etrace::some_or; use rustc_abi::VariantIdx; use rustc_hir::{ @@ -14,7 +15,7 @@ use rustc_hir::{ use rustc_index::bit_set::BitSet; use rustc_middle::{ hir::nested_filter, - mir::{BasicBlock, Body, Local, Location, TerminatorKind}, + mir::{BasicBlock, Body, Local, Location, StatementKind, TerminatorKind}, ty::{AdtKind, Ty, TyCtxt, TyKind, TypeAndMut}, }; use rustc_session::config::Input; @@ -47,7 +48,24 @@ impl Default for AnalysisConfig { } } -pub type AnalysisResult = BTreeMap>; +pub type AnalysisResult = BTreeMap; + +#[derive(Debug, Serialize, Deserialize)] +pub struct FnAnalysisRes { + pub output_params: Vec, + pub wbrs: Vec, + pub rcfws: Rcfws, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct WriteBeforeReturn { + pub span: LoHi, + pub mays: BTreeSet, + pub musts: BTreeSet, +} + +// Removable checks for Write s +pub type Rcfws = BTreeMap>; pub fn analyze_path(path: &Path, conf: &AnalysisConfig) -> AnalysisResult { analyze_input(compile_util::path_to_input(path), conf) @@ -62,16 +80,18 @@ pub fn analyze_input(input: Input, conf: &AnalysisConfig) -> AnalysisResult { compile_util::run_compiler(config, |tcx| { analyze(tcx, conf) .into_iter() - .filter_map(|(def_id, (_, params))| { - if params.is_empty() { + .filter_map(|(def_id, (_, res))| { + if res.output_params.is_empty() { None } else { - Some((tcx.def_path_str(def_id), params)) + Some((tcx.def_path_str(def_id), res)) } }) - .collect() + .collect::>() }) .unwrap() + .into_iter() + .collect() } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] @@ -84,7 +104,7 @@ enum Write { pub fn analyze( tcx: TyCtxt<'_>, conf: &AnalysisConfig, -) -> BTreeMap)> { +) -> BTreeMap { let hir = tcx.hir(); let mut call_graph = BTreeMap::new(); @@ -150,6 +170,13 @@ pub fn analyze( let mut wm_map = 
BTreeMap::new(); let mut call_args_map = BTreeMap::new(); let mut analysis_times: BTreeMap<_, u128> = BTreeMap::new(); + + let mut wbrs: BTreeMap> = BTreeMap::new(); + let mut bb_musts: BTreeMap>> = BTreeMap::new(); + let mut is_units = BTreeMap::new(); + + let mut rcfws = BTreeMap::new(); + for id in &po { let def_ids = &elems[id]; let recursive = if def_ids.len() == 1 { @@ -199,7 +226,61 @@ pub fn analyze( .flat_map(|p| analyzer.expands_path(&AbsPath(vec![*p]))) .collect(); - let mut return_states = return_location(body) + let ret_location = return_location(body); + + let mut wbr = vec![]; + let mut bb_must = BTreeMap::new(); + + let mut stack = vec![]; + + if let Some(ret_location) = ret_location { + if let Some((ret_loc_assign0, ret_loc)) = + exists_assign0(body, ret_location.block) + { + stack.push((ret_loc, ret_loc_assign0)); + } else if let Some(v) = body.basic_blocks.predecessors().get(ret_location.block) + { + for i in v { + if let Some((sp, ret_loc)) = exists_assign0(body, *i) { + stack.push((ret_loc, sp)); + } + } + } + + let empty_map = BTreeMap::new(); + for (loc, sp) in stack.iter() { + let writes: Vec<_> = states + .get(loc) + .unwrap_or(&empty_map) + .values() + .map(|st| st.writes.as_set()) + .collect(); + let musts: BTreeSet<_> = writes + .iter() + .copied() + .fold(None, |acc: Option>, ws| { + Some(match acc { + Some(acc) => acc.intersection(ws).cloned().collect(), + None => ws.clone(), + }) + }) + .unwrap_or_default() + .iter() + .map(|p| p.base() - 1) + .collect(); + let mays: BTreeSet<_> = + writes.into_iter().flatten().map(|p| p.base() - 1).collect(); + let span = LoHi::from_span(*sp); + bb_must.insert(loc.block, musts.clone()); + wbr.push(WriteBeforeReturn { span, mays, musts }); + } + } + + wbrs.insert(*def_id, wbr); + bb_musts.insert(*def_id, bb_must); + is_units.insert(*def_id, stack.is_empty()); + + let mut return_states = ret_location .and_then(|ret| states.get(&ret)) .cloned() .unwrap_or_default(); @@ -240,12 +321,74 @@ pub fn analyze( 
for p in &mut output_params { analyzer.find_complete_write(p, &result, &writes_map, &call_args, *def_id); } + + let body = tcx.optimized_mir(*def_id); + let bb_must = &bb_musts[def_id]; + let mut rcfw: Rcfws = BTreeMap::new(); + + if !is_units[def_id] { + for p in &output_params { + let OutputParam { + index, + complete_writes, + .. + } = p; + for complete_write in complete_writes { + let CompleteWrite { + block, + statement_index, + .. + } = complete_write; + + let mut stack = vec![BasicBlock::from_usize(*block)]; + let mut visited: BTreeSet<_> = stack.iter().cloned().collect(); + + let always_write = loop { + if let Some(bb) = stack.pop() { + if let Some(musts) = bb_must.get(&bb) { + if !musts.contains(index) { + break false; + } + } + + let term = body.basic_blocks[bb].terminator(); + for bb in term.successors() { + if !visited.contains(&bb) { + visited.insert(bb); + stack.push(bb); + } + } + } else { + break true; + } + }; + + if always_write { + let location = Location { + block: BasicBlock::from_usize(*block), + statement_index: *statement_index, + }; + let span = LoHi::from_span(body.source_info(location).span); + let entry = rcfw.entry(*index); + entry.or_default().insert(span); + } + } + } + } + + rcfws.insert(*def_id, rcfw); output_params_map.insert(*def_id, output_params); } break; } } } + + if conf.max_loop_head_states <= 1 { + wbrs.clear(); + rcfws.clear(); + } + if let Some(n) = &conf.function_times { let mut analysis_times: Vec<_> = analysis_times.into_iter().collect(); analysis_times.sort_by_key(|(_, t)| u128::MAX - *t); @@ -262,7 +405,14 @@ pub fn analyze( .into_iter() .map(|(def_id, summary)| { let output_params = output_params_map.remove(&def_id).unwrap(); - (def_id, (summary, output_params)) + let wbrs = wbrs.remove(&def_id).unwrap_or_default(); + let rcfws = rcfws.remove(&def_id).unwrap_or_default(); + let res = FnAnalysisRes { + output_params, + wbrs, + rcfws, + }; + (def_id, (summary, res)) }) .collect() } @@ -1062,6 +1212,44 @@ fn 
return_location(body: &Body<'_>) -> Option { None } +fn exists_assign0(body: &Body<'_>, bb: BasicBlock) -> Option<(Span, Location)> { + for (i, stmt) in body.basic_blocks[bb].statements.iter().enumerate() { + if let StatementKind::Assign(rb) = &stmt.kind { + if (**rb).0.local.as_u32() == 0u32 { + return Some(( + stmt.source_info.span, + Location { + block: bb, + statement_index: i, + }, + )); + } + } + } + let term = body.basic_blocks[bb].terminator(); + if let TerminatorKind::Call { + func: _, + args: _, + destination, + target, + unwind: _, + call_source: _, + fn_span: _, + } = term.kind + { + if destination.local.as_u32() == 0u32 { + return Some(( + term.source_info.span, + Location { + block: target.unwrap(), + statement_index: 0, + }, + )); + } + } + None +} + fn get_rpo_map(body: &Body<'_>) -> BTreeMap { body.basic_blocks .reverse_postorder() diff --git a/src/bin/nopcrat.rs b/src/bin/nopcrat.rs index 468ae5a..00a6a95 100644 --- a/src/bin/nopcrat.rs +++ b/src/bin/nopcrat.rs @@ -25,6 +25,8 @@ struct Args { #[arg(short, long)] transform: bool, + #[arg(long)] + simplify: bool, #[arg(short, long)] size: bool, #[arg(long)] @@ -98,9 +100,9 @@ fn main() { }; if args.verbose { - for (func, params) in &analysis_result { + for (func, res) in &analysis_result { println!("{}", func); - for param in params { + for param in &res.output_params { println!(" {:?}", param); } } @@ -117,7 +119,7 @@ fn main() { if args.sample_may || args.sample_must { let mut params: Vec<_> = analysis_result .iter() - .filter(|(_, params)| params.iter().any(|p| p.must == args.sample_must)) + .filter(|(_, res)| res.output_params.iter().any(|p| p.must == args.sample_must)) .collect(); params.shuffle(&mut thread_rng()); for (f, ps) in params.iter().take(10) { @@ -130,11 +132,11 @@ fn main() { let fns = analysis_result.len(); let musts = analysis_result .values() - .map(|v| v.iter().filter(|p| p.must).count()) + .map(|res| res.output_params.iter().filter(|p| p.must).count()) .sum::(); let mays = 
analysis_result .values() - .map(|v| v.iter().filter(|p| !p.must).count()) + .map(|res| res.output_params.iter().filter(|p| !p.must).count()) .sum::(); println!("{} {} {}", fns, musts, mays); } @@ -148,7 +150,7 @@ fn main() { return; } - transform::transform_path(path, &analysis_result); + transform::transform_path(path, &analysis_result, args.simplify); } fn clear_dir(path: &Path) { diff --git a/src/compile_util.rs b/src/compile_util.rs index 4560401..41f8c32 100644 --- a/src/compile_util.rs +++ b/src/compile_util.rs @@ -22,9 +22,57 @@ use rustc_session::{ use rustc_span::{ edition::Edition, source_map::{FileName, SourceMap}, - RealFileName, Span, + BytePos, RealFileName, Span, SpanData, SyntaxContext, }; use rustfix::{LinePosition, LineRange, Replacement, Snippet, Solution, Suggestion}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct LoHi { + lo: u32, + hi: u32, +} + +impl LoHi { + #[inline] + fn new(lo: u32, hi: u32) -> Self { + Self { lo, hi } + } + + #[inline] + pub fn from_span(span: Span) -> Self { + assert!(span.ctxt().is_root()); + assert!(span.parent().is_none()); + Self::new(span.lo().0, span.hi().0) + } + + #[inline] + pub fn from_span_data(span: SpanData) -> Self { + assert!(span.ctxt.is_root()); + assert!(span.parent.is_none()); + Self::new(span.lo.0, span.hi.0) + } + + #[inline] + pub fn to_span(self) -> Span { + Span::new( + BytePos(self.lo), + BytePos(self.hi), + SyntaxContext::root(), + None, + ) + } + + #[inline] + pub fn to_span_data(self) -> SpanData { + SpanData { + lo: BytePos(self.lo), + hi: BytePos(self.hi), + ctxt: SyntaxContext::root(), + parent: None, + } + } +} pub fn run_compiler) -> R + Send>(config: Config, f: F) -> Option { rustc_driver::catch_fatal_errors(|| { @@ -99,6 +147,13 @@ pub fn span_to_path(span: Span, source_map: &SourceMap) -> Option { pub fn apply_suggestions>(suggestions: &BTreeMap>) { for (path, suggestions) in suggestions { let 
code = String::from_utf8(fs::read(path).unwrap()).unwrap(); + // for suggestion in suggestions { + // println!("{:?}", path.as_ref()); + // println!("{:?}", suggestion.snippets[0].line_range); + // println!("{:?}", suggestion.snippets[0].range); + // println!("{}", suggestion.solutions[0].replacements[0].replacement); + // println!(); + // } let fixed = rustfix::apply_suggestions(&code, suggestions).unwrap(); fs::write(path, fixed.as_bytes()).unwrap(); } diff --git a/src/transform.rs b/src/transform.rs index d9394e4..e7c2c85 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -19,17 +19,37 @@ use rustfix::Suggestion; use crate::{ai::analysis::*, compile_util}; -pub fn transform_path(path: &Path, params: &BTreeMap>) { +#[derive(Default, Clone, Copy)] +struct Counter { + simplify: bool, + removed_pointer_defs: usize, + removed_pointer_uses: usize, + direct_returns: usize, + success_returns: usize, + failure_returns: usize, + removed_flag_sets: usize, + removed_flag_defs: usize, +} + +pub fn transform_path(path: &Path, analysis_result: &AnalysisResult, simplify: bool) { let input = compile_util::path_to_input(path); let config = compile_util::make_config(input); - let suggestions = compile_util::run_compiler(config, |tcx| transform(tcx, params)).unwrap(); + let suggestions = + compile_util::run_compiler(config, |tcx| transform(tcx, analysis_result, simplify)) + .unwrap(); compile_util::apply_suggestions(&suggestions); } fn transform( tcx: TyCtxt<'_>, - param_map: &BTreeMap>, + analysis_result: &AnalysisResult, + simplify: bool, ) -> BTreeMap> { + let mut counter = Counter { + simplify, + ..Counter::default() + }; + let hir = tcx.hir(); let source_map = tcx.sess.source_map(); @@ -48,6 +68,8 @@ fn transform( } let mut funcs = BTreeMap::new(); + let mut wbrs = BTreeMap::new(); + let mut rcfws = BTreeMap::new(); for id in hir.items() { let item = hir.item(id); let ItemKind::Fn(sig, _, body_id) = item.kind else { @@ -55,10 +77,11 @@ fn transform( }; let def_id = 
id.owner_id.to_def_id(); let name = tcx.def_path_str(def_id); - let params = some_or!(param_map.get(&name), continue); + let fn_analysis_result = some_or!(analysis_result.get(&name), continue); let body = hir.body(body_id); let mir_body = tcx.optimized_mir(def_id); - let index_map: BTreeMap<_, _> = params + let index_map: BTreeMap<_, _> = fn_analysis_result + .output_params .iter() .map(|param| { let OutputParam { @@ -135,13 +158,51 @@ fn transform( (*index, param) }) .collect(); + + wbrs.insert( + def_id, + fn_analysis_result + .wbrs + .iter() + .map(|wbr| { + ( + wbr.span.to_span(), + ( + wbr.mays + .iter() + .filter_map(|i| index_map.get(i).map(|p| p.name.clone())) + .collect::>(), + wbr.musts + .iter() + .filter_map(|i| index_map.get(i).map(|p| p.name.clone())) + .collect::>(), + ), + ) + }) + .collect::>(), + ); + + rcfws.insert( + def_id, + fn_analysis_result + .rcfws + .iter() + .map(|(index, spans)| { + ( + index_map.get(index).cloned().unwrap().name, + spans.iter().map(|sp| sp.to_span()).collect(), + ) + }) + .collect::>>(), + ); + let hir_id_map: BTreeMap<_, _> = index_map .values() .cloned() .map(|param| (param.hir_id, param)) .collect(); let mut remaining_return: Vec<_> = index_map.keys().copied().collect(); - let first_return = SuccValue::find(params); + let first_return = SuccValue::find(&fn_analysis_result.output_params); if let Some((_, first)) = &first_return { remaining_return.retain(|i| i != first); } @@ -175,56 +236,61 @@ fn transform( let body = hir.body(body_id); let curr = funcs.get(&def_id); + let default = BTreeMap::new(); + let rcfw = rcfws.get(&def_id).unwrap_or(&default); + let mut visitor = BodyVisitor::new(tcx); visitor.visit_body(body); - let mut pass_visitor = PassVisitor::new(tcx); - pass_visitor.visit_body(body); - - let passes = BTreeSet::from_iter(pass_visitor.passes.iter()); + let passes = BTreeSet::from_iter(visitor.passes.iter()); let mut ret_call_spans = BTreeSet::new(); let mut call_spans = BTreeSet::new(); + // in deref 
expression, pointer name to expression span let mut ref_to_spans: BTreeMap> = BTreeMap::new(); + // in deref expression in return, return span to pointer name and deref span let mut ret_to_ref_spans: BTreeMap> = BTreeMap::new(); + // in deref expression in call, pointer name and deref span let mut ref_and_call_spans = vec![]; - if let Some(func) = curr { - let hirids = BTreeSet::from_iter( - visitor - .calls - .iter() - .filter_map(|c| funcs.get(&c.callee).map(|_| c.hir_id)), - ); - let params = BTreeSet::from_iter(func.params().map(|p| p.name.clone())); - - for rf in visitor.refs { - let Ref { hir_id, span, name } = rf; - if params.contains(&name) && !passes.contains(&name) { - if let Some(expr) = get_parent_call(hir_id, tcx) { - if hirids.contains(&expr.hir_id) { - ref_and_call_spans.push((name, span)); + if simplify { + if let Some(func) = curr { + let hirids = BTreeSet::from_iter( + visitor + .calls + .iter() + .filter_map(|c| funcs.get(&c.callee).map(|_| c.hir_id)), + ); + let params = BTreeSet::from_iter(func.params().map(|p| p.name.clone())); + + for rf in visitor.refs { + let Ref { hir_id, span, name } = rf; + if params.contains(&name) && !passes.contains(&name) { + if let Some(expr) = get_parent_call(hir_id, tcx) { + if hirids.contains(&expr.hir_id) { + ref_and_call_spans.push((name, span)); + continue; + } + } + + if let Some(expr) = get_parent_return(hir_id, tcx) { + ret_to_ref_spans + .entry(expr.span) + .or_default() + .push((name, span)); continue; } - } - if let Some(expr) = get_parent_return(hir_id, tcx) { - ret_to_ref_spans - .entry(expr.span) - .or_default() - .push((name, span)); - continue; + ref_to_spans.entry(name.clone()).or_default().push(span); } - - ref_to_spans.entry(name.clone()).or_default().push(span); } } } - for call in visitor.calls { + for call in visitor.calls.clone() { let Call { hir_id, span, @@ -237,7 +303,7 @@ fn transform( let mut args = args.clone(); for arg in args.iter_mut() { - for (name, span) in 
ref_and_call_spans.iter() { + for (name, span) in &ref_and_call_spans { if arg.span.contains(*span) { let pre_span = arg.span.with_hi(span.lo()); let post_span = arg.span.with_lo(span.hi()); @@ -246,6 +312,7 @@ fn transform( let post_s = source_map.span_to_snippet(post_span).unwrap(); arg.code = format!("{}{}___v{}", pre_s, name, post_s); + counter.removed_pointer_uses += 1; } } } @@ -256,7 +323,11 @@ fn transform( } let assign_map = curr.map(|c| c.assign_map(span)).unwrap_or_default(); - let mut mtch = func.call_match(&args, &assign_map); + + let mut mtch = func.first_return.and_then(|(_, first)| { + let set_flag = generate_set_flag(&span, &first, rcfw, &assign_map, &mut counter); + func.call_match(&args, set_flag) + }); if let Some(call) = get_if_cmp_call(hir_id, span, tcx) { if let Some(then) = func.cmp(call.op, call.target) { @@ -268,11 +339,7 @@ fn transform( let fail = "Err(_) => "; let (_, i) = func.first_return.as_ref().unwrap(); let arg = &args[*i]; - let set_flag = if let Some(arg) = assign_map.get(i) { - format!("{}___s = true;", arg) - } else { - "".to_string() - }; + let set_flag = generate_set_flag(&span, i, rcfw, &assign_map, &mut counter); let assign = if arg.code.contains("&mut ") { format!(" *({}) = v___; {}", arg.code, set_flag) } else { @@ -338,7 +405,15 @@ fn transform( let post_span = post_span.with_hi(post_span.hi() + BytePos(1)); let rv = format!("{}rv___{}", pre_s, post_s); - let rv = func.return_value(Some(rv)); + let rv = func.return_value( + Some(rv), + wbrs[&def_id] + .iter() + .find(|(sp, _)| span.contains(*sp)) + .map(|r| &r.1), + None, + &mut counter, + ); fix( post_span, format!("; return {};{}", rv, if arm { " }" } else { "" }), @@ -352,7 +427,18 @@ fn transform( } fix(span.shrink_to_lo(), binding); - let mut assign = func.call_assign(&args, &assign_map); + let set_flags = func + .remaining_return + .iter() + .map(|i| { + ( + *i, + generate_set_flag(&span, i, rcfw, &assign_map, &mut counter), + ) + }) + .collect(); + + let mut 
assign = func.call_assign(&args, &set_flags); if let Some(m) = &mtch { assign += m; assign += ")"; @@ -376,11 +462,39 @@ fn transform( let ret_ty = func.return_type(orig); fix(span, format!("-> {}", ret_ty)); + let mut unremovable = BTreeSet::new(); + for param in func.params() { + let rcfw = rcfw.get(&param.name); + + for span in &param.writes { + if let Some(rcfw) = rcfw { + if rcfw.iter().any(|sp| span.contains(*sp)) && simplify { + counter.removed_flag_sets += 1; + continue; + } + } + + unremovable.insert(&param.name); + + if call_spans.contains(span) { + continue; + } + + let pos = span.hi() + BytePos(1); + let span = span.with_hi(pos).with_lo(pos); + let assign = format!("{0}___s = true;", param.name); + fix(span, assign); + } + } + let local_vars: String = func .params() .map(|param| { - if param.must { - if passes.contains(&param.name) { + if param.must || (!unremovable.contains(&param.name) && simplify) { + if !param.must { + counter.removed_flag_defs += 1; + } + if passes.contains(&param.name) || !simplify { format!( " let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]); \ @@ -388,25 +502,27 @@ fn transform( param.name, param.ty, ) } else { + counter.removed_pointer_defs += 1; format!( " let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]);", param.name, param.ty, ) } - } else if passes.contains(&param.name) { + } else if passes.contains(&param.name) || !simplify { format!( " - let mut {0}___s: bool = false; \ - let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]); \ - let mut {0}: *mut {1} = &mut {0}___v;", + let mut {0}___s: bool = false; \ + let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]); \ + let mut {0}: *mut {1} = &mut {0}___v;", param.name, param.ty, ) } else { + counter.removed_pointer_defs += 1; format!( " - let mut {0}___s: bool = false; \ - let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]);", + let mut {0}___s: bool = false; \ + let mut 
{0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]);", param.name, param.ty, ) } @@ -417,39 +533,83 @@ fn transform( let span = body.value.span.with_lo(pos).with_hi(pos); fix(span, local_vars); - for param in func.params() { - for span in &param.writes { - if call_spans.contains(span) { + for ret in visitor.returns.iter() { + let Return { span, value } = ret; + if ret_call_spans.contains(span) || ret_to_ref_spans.contains_key(span) { + continue; + } + + let orig = value.map(|value| source_map.span_to_snippet(value).unwrap()); + + let mut assign_before_ret = None; + for assign in visitor.assigns.iter() { + if visitor + .calls + .iter() + .any(|call| call.span.overlaps(assign.span)) + { continue; } - let pos = span.hi() + BytePos(1); - let span = span.with_hi(pos).with_lo(pos); - let assign = format!("{0}___s = true;", param.name); - fix(span, assign); + if source_map + .span_to_snippet(assign.span.between(*span)) + .unwrap() + .chars() + .all(|c| c.is_whitespace()) + { + assign_before_ret = Some(assign); + break; + } + } + + let mut lit_map = None; + + if let Some(assign_before_ret) = assign_before_ret { + let Assign { + name, + value, + span: sp, + } = assign_before_ret; + + if let Some(spans) = ref_to_spans.get_mut(name) { + spans.retain(|span| !sp.contains(*span)); + lit_map = Some((name, value)); + fix(*sp, "".to_string()); + } } + + let ret_v = func.return_value( + orig, + wbrs[&def_id] + .iter() + .find(|(sp, _)| span.contains(*sp)) + .map(|r| &r.1), + lit_map, + &mut counter, + ); + fix(*span, format!("return {}", ret_v)); } for param in func.params() { if let Some(spans) = ref_to_spans.get(&param.name) { for span in spans { let assign = format!("{}___v", param.name); + counter.removed_pointer_uses += 1; fix(*span, assign); } } } - for (span, ss) in ret_to_ref_spans.iter() { + for (span, mut ss) in ret_to_ref_spans { let mut post_spans = vec![]; - let mut sorted_ss = ss.clone(); - sorted_ss.sort_by_key(|x| x.1); + ss.sort_by_key(|x| x.1); - let s = 
sorted_ss[0].1; + let s = ss[0].1; let pre_span = span.with_hi(s.lo()).with_lo(span.lo() + BytePos(6)); post_spans.push(span.with_lo(s.hi())); - for (i, (_, s)) in sorted_ss[1..].iter().enumerate() { + for (i, (_, s)) in ss[1..].iter().enumerate() { post_spans[i] = post_spans[i].with_hi(s.lo()); post_spans.push(span.with_lo(s.hi())); } @@ -457,32 +617,49 @@ fn transform( let pre_s = source_map.span_to_snippet(pre_span).unwrap(); let post_s = source_map.span_to_snippet(post_spans[0]).unwrap(); - let mut rv = format!("{}{}___v{}", pre_s, sorted_ss[0].0, post_s); + let mut rv = format!("{}{}___v{}", pre_s, ss[0].0, post_s); + counter.removed_pointer_uses += 1; for (i, s) in post_spans[1..].iter().enumerate() { let post_s = source_map.span_to_snippet(*s).unwrap(); - rv = format!("{}{}___v{}", rv, sorted_ss[i + 1].0, post_s); + rv.push_str(&ss[i + 1].0); + rv.push_str("___v"); + rv.push_str(&post_s); + counter.removed_pointer_uses += 1; } - let rv = func.return_value(Some(rv)); - - fix(*span, format!("return {}", rv)); - } + let rv = func.return_value( + Some(rv), + wbrs[&def_id] + .iter() + .find(|(sp, _)| span.contains(*sp)) + .map(|r| &r.1), + None, + &mut counter, + ); - for ret in visitor.returns { - let Return { span, value } = ret; - if ret_call_spans.contains(&span) || ret_to_ref_spans.contains_key(&span) { - continue; - } - let orig = value.map(|value| source_map.span_to_snippet(value).unwrap()); - let ret_v = func.return_value(orig); - fix(span, format!("return {}", ret_v)); + fix(span, format!("return {}", rv)); } if func.is_unit { let pos = body.value.span.hi() - BytePos(1); let span = body.value.span.with_lo(pos).with_hi(pos); - let ret_v = func.return_value(None); - fix(span, ret_v); + let pos = pos - BytePos(3); + let prev = span.with_lo(pos).with_hi(pos); + + let mut skip = false; + + for ret in visitor.returns { + let Return { span, value: _ } = ret; + if span.overlaps(prev) { + skip = true; + break; + } + } + + if !skip { + let ret_v = 
func.return_value(None, None, None, &mut counter); + fix(span, format!("\t{}\n", ret_v)); + } } } suggestions.retain(|_, v| !v.is_empty()); @@ -495,6 +672,15 @@ fn transform( ) }); } + + println!("Removed pointer defs: {}", counter.removed_pointer_defs); + println!("Removed pointer uses: {}", counter.removed_pointer_uses); + println!("Direct returns: {}", counter.direct_returns); + println!("Success returns: {}", counter.success_returns); + println!("Failure returns: {}", counter.failure_returns); + println!("Removed flag sets: {}", counter.removed_flag_sets); + println!("Removed flag defs: {}", counter.removed_flag_defs); + suggestions } @@ -566,34 +752,29 @@ impl Func { map } - fn call_assign(&self, args: &[Arg], assign_map: &BTreeMap) -> String { + fn call_assign(&self, args: &[Arg], set_flags: &BTreeMap) -> String { let mut assigns = vec![]; for i in &self.remaining_return { let arg = &args[*i]; let param = &self.index_map[i]; - let set_flag = if let Some(arg) = assign_map.get(i) { - format!("{}___s = true;", arg) - } else { - "".to_string() - }; let assign = if param.must { if arg.code.contains("&mut ") { - format!("*({}) = rv___{}; {}", arg.code, i, set_flag) + format!("*({}) = rv___{}; {}", arg.code, i, set_flags[i]) } else { format!( "if !({0}).is_null() {{ *({0}) = rv___{1}; {2} }}", - arg.code, i, set_flag + arg.code, i, set_flags[i] ) } } else if arg.code.contains("&mut ") { format!( "if let Some(v___) = rv___{} {{ *({}) = v___; {} }}", - i, arg.code, set_flag + i, arg.code, set_flags[i] ) } else { format!( "if !({0}).is_null() {{ if let Some(v___) = rv___{1} {{ *({0}) = v___; {2} }} }}", - arg.code, i, set_flag + arg.code, i, set_flags[i] ) }; assigns.push(assign); @@ -602,14 +783,9 @@ impl Func { mk_string(assigns.iter(), "; ", " ", end) } - fn call_match(&self, args: &[Arg], assign_map: &BTreeMap) -> Option { + fn call_match(&self, args: &[Arg], set_flag: String) -> Option { let (succ_value, first) = &self.first_return?; let arg = &args[*first]; - let 
set_flag = if let Some(arg) = assign_map.get(first) { - format!("{}___s = true;", arg) - } else { - "".to_string() - }; let assign = if arg.code.contains("&mut ") { format!("*({}) = v___; {}", arg.code, set_flag) } else { @@ -655,25 +831,82 @@ impl Func { } } - fn return_value(&self, orig: Option) -> String { + fn return_value( + &self, + orig: Option, + wbr: Option<&(Vec, Vec)>, + lit_map: Option<(&String, &String)>, + counter: &mut Counter, + ) -> String { let mut values = vec![]; if let Some((_, i)) = &self.first_return { let orig = orig.unwrap(); let param = &self.index_map[i]; - let v = format!( - "if {0}___s {{ Ok({0}___v) }} else {{ Err({1}) }}", - param.name, orig - ); + let name = lit_map + .and_then(|(n, v)| { + if *n == param.name && counter.simplify { + counter.direct_returns += 1; + Some((*v).clone()) + } else { + None + } + }) + .unwrap_or(format!("{}___v", param.name)); + let v = if let Some((may, must)) = wbr { + if must.contains(&param.name) && counter.simplify { + counter.success_returns += 1; + format!("Ok({})", name) + } else if !may.contains(&param.name) && counter.simplify { + counter.failure_returns += 1; + format!("Err({})", orig) + } else { + format!( + "if {0}___s {{ Ok({1}) }} else {{ Err({2}) }}", + param.name, name, orig + ) + } + } else { + format!( + "if {0}___s {{ Ok({1}) }} else {{ Err({2}) }}", + param.name, name, orig + ) + }; values.push(v); } else if let Some(v) = orig { values.push(v); } for i in &self.remaining_return { let param = &self.index_map[i]; + let name = lit_map + .and_then(|(n, v)| { + if *n == param.name && counter.simplify { + counter.direct_returns += 1; + Some((*v).clone()) + } else { + None + } + }) + .unwrap_or(format!("{}___v", param.name)); let v = if param.must { - format!("{}___v", param.name) + name + } else if let Some((may, must)) = wbr { + if must.contains(&param.name) && counter.simplify { + counter.success_returns += 1; + format!("Some({})", name) + } else if !may.contains(&param.name) && counter.simplify { + 
counter.failure_returns += 1; + "None".to_string() + } else { + format!( + "if {0}___s {{ Some({1}) }} else {{ None }}", + param.name, name + ) + } } else { - format!("if {0}___s {{ Some({0}___v) }} else {{ None }}", param.name) + format!( + "if {0}___s {{ Some({1}) }} else {{ None }}", + param.name, name + ) }; values.push(v); } @@ -685,13 +918,13 @@ impl Func { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct Return { span: Span, value: Option, } -#[derive(Debug)] +#[derive(Debug, Clone)] struct Call { hir_id: HirId, span: Span, @@ -712,44 +945,11 @@ struct Ref { name: String, } -struct PassVisitor<'tcx> { - tcx: TyCtxt<'tcx>, - passes: Vec, -} - -impl<'tcx> PassVisitor<'tcx> { - fn new(tcx: TyCtxt<'tcx>) -> Self { - Self { - tcx, - passes: vec![], - } - } -} - -impl<'tcx> HVisitor<'tcx> for PassVisitor<'tcx> { - type NestedFilter = nested_filter::OnlyBodies; - - fn nested_visit_map(&mut self) -> Self::Map { - self.tcx.hir() - } - - fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) { - let source_map = self.tcx.sess.source_map(); - if let ExprKind::Path(p) = expr.kind { - if let Ok(code) = source_map.span_to_snippet(p.qself_span()) { - match get_parent(expr.hir_id, self.tcx) { - Some(e) => { - if let ExprKind::Unary(UnOp::Deref, _) = e.kind { - } else { - self.passes.push(code) - } - } - None => self.passes.push(code), - } - } - } - rustc_hir::intravisit::walk_expr(self, expr); - } +#[derive(Debug)] +struct Assign { + name: String, + value: String, + span: Span, } struct BodyVisitor<'tcx> { @@ -757,6 +957,8 @@ struct BodyVisitor<'tcx> { returns: Vec, calls: Vec, refs: Vec, + passes: Vec, + assigns: Vec, } impl<'tcx> BodyVisitor<'tcx> { @@ -766,6 +968,8 @@ impl<'tcx> BodyVisitor<'tcx> { returns: vec![], calls: vec![], refs: vec![], + passes: vec![], + assigns: vec![], } } } @@ -812,6 +1016,46 @@ impl<'tcx> BodyVisitor<'tcx> { }; self.calls.push(call); } + + fn visit_expr_path(&mut self, expr: &'tcx Expr<'tcx>, path: QPath<'tcx>) { + let source_map = 
self.tcx.sess.source_map(); + if let Ok(code) = source_map.span_to_snippet(path.qself_span()) { + match get_parent(expr.hir_id, self.tcx) { + Some(e) => { + if let ExprKind::Unary(UnOp::Deref, _) = e.kind { + } else { + self.passes.push(code) + } + } + None => self.passes.push(code), + } + } + } + + fn visit_expr_unary(&mut self, expr: &'tcx Expr<'tcx>, e: &'tcx Expr<'tcx>) { + let source_map = self.tcx.sess.source_map(); + self.refs.push(Ref { + hir_id: expr.hir_id, + span: expr.span, + name: source_map.span_to_snippet(e.span).unwrap(), + }) + } + + fn visit_expr_assign( + &mut self, + expr: &'tcx Expr<'tcx>, + lhs: &'tcx Expr<'tcx>, + rhs: &'tcx Expr<'tcx>, + ) { + let source_map = self.tcx.sess.source_map(); + if let ExprKind::Unary(UnOp::Deref, e) = lhs.kind { + self.assigns.push(Assign { + name: source_map.span_to_snippet(e.span).unwrap(), + value: source_map.span_to_snippet(rhs.span).unwrap(), + span: expr.span.with_hi(expr.span.hi() + BytePos(1)), + }); + } + } } impl<'tcx> HVisitor<'tcx> for BodyVisitor<'tcx> { @@ -822,15 +1066,12 @@ impl<'tcx> HVisitor<'tcx> for BodyVisitor<'tcx> { } fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) { - let source_map = self.tcx.sess.source_map(); match expr.kind { - ExprKind::Unary(UnOp::Deref, e) => self.refs.push(Ref { - hir_id: expr.hir_id, - span: expr.span, - name: source_map.span_to_snippet(e.span).unwrap(), - }), ExprKind::Ret(e) => self.visit_expr_ret(expr, e), ExprKind::Call(callee, args) => self.visit_expr_call(expr, callee, args), + ExprKind::Path(path) => self.visit_expr_path(expr, path), + ExprKind::Unary(UnOp::Deref, e) => self.visit_expr_unary(expr, e), + ExprKind::Assign(lhs, rhs, _) => self.visit_expr_assign(expr, lhs, rhs), _ => {} } rustc_hir::intravisit::walk_expr(self, expr); @@ -1014,3 +1255,23 @@ fn mk_string, I: Iterator>( s.push_str(end); s } + +fn generate_set_flag( + span: &Span, + i: &usize, + rcfws: &BTreeMap>, + assign_map: &BTreeMap, + counter: &mut Counter, +) -> String { + if let Some(arg) 
= assign_map.get(i) { + let rcfw = &rcfws.get(arg); + if let Some(rcfw) = rcfw { + if rcfw.iter().any(|sp| span.contains(*sp)) && counter.simplify { + counter.removed_flag_sets += 1; + return "".to_string(); + } + } + return format!("{}___s = true;", arg); + } + "".to_string() +}