From c1144e1823e0c1321191114bc716d7bad17b9513 Mon Sep 17 00:00:00 2001 From: HoseongLee Date: Tue, 13 Aug 2024 17:38:14 +0900 Subject: [PATCH 01/14] remove ___s check if possible --- src/ai/analysis.rs | 117 +++++++++++++++++++++++++++++++++++++++++--- src/bin/nopcrat.rs | 11 +++-- src/compile_util.rs | 7 +++ src/sampling.rs | 8 +-- src/transform.rs | 92 +++++++++++++++++++++++++++------- 5 files changed, 201 insertions(+), 34 deletions(-) diff --git a/src/ai/analysis.rs b/src/ai/analysis.rs index 24ac045..44639fc 100644 --- a/src/ai/analysis.rs +++ b/src/ai/analysis.rs @@ -14,7 +14,7 @@ use rustc_hir::{ use rustc_index::bit_set::BitSet; use rustc_middle::{ hir::nested_filter, - mir::{BasicBlock, Body, Local, Location, TerminatorKind}, + mir::{BasicBlock, Body, Local, Location, StatementKind, TerminatorKind}, ty::{AdtKind, Ty, TyCtxt, TyKind, TypeAndMut}, }; use rustc_session::config::Input; @@ -47,7 +47,10 @@ impl Default for AnalysisConfig { } } -pub type AnalysisResult = BTreeMap>; +pub type OutputParams = BTreeMap>; +pub type Wbrets = BTreeMap>; + +pub type AnalysisResult = (OutputParams, BTreeMap); pub fn analyze_path(path: &Path, conf: &AnalysisConfig) -> AnalysisResult { analyze_input(compile_util::path_to_input(path), conf) @@ -62,16 +65,19 @@ pub fn analyze_input(input: Input, conf: &AnalysisConfig) -> AnalysisResult { compile_util::run_compiler(config, |tcx| { analyze(tcx, conf) .into_iter() - .filter_map(|(def_id, (_, params))| { + .filter_map(|(def_id, (_, params, writes))| { if params.is_empty() { None } else { - Some((tcx.def_path_str(def_id), params)) + Some((tcx.def_path_str(def_id), (params, writes))) } }) - .collect() + .collect::>() }) .unwrap() + .into_iter() + .map(|(k, (v1, v2))| ((k.clone(), v1), (k, v2))) + .unzip() } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] @@ -84,7 +90,14 @@ enum Write { pub fn analyze( tcx: TyCtxt<'_>, conf: &AnalysisConfig, -) -> BTreeMap)> { +) -> BTreeMap< + DefId, + ( + FunctionSummary, + Vec, + Wbrets, + ), +> { let hir = tcx.hir(); let mut call_graph = BTreeMap::new(); @@ -150,6 +163,7 @@ pub fn analyze( let mut wm_map = BTreeMap::new(); let mut call_args_map = BTreeMap::new(); let mut analysis_times: BTreeMap<_, u128> = BTreeMap::new(); + let mut wbrets: BTreeMap>> = BTreeMap::new(); for id in &po { let def_ids = &elems[id]; let recursive = if def_ids.len() == 1 { @@ -199,7 +213,76 @@ pub fn analyze( .flat_map(|p| analyzer.expands_path(&AbsPath(vec![*p]))) .collect(); - let mut return_states = return_location(body) + let ret_location = return_location(body); + + let mut wbret = BTreeMap::new(); + + if let Some(ret_location) = ret_location { + if let Some(ret_loc_assign0) = exists_assign0(body, ret_location.block) { + let writes: BTreeSet<_> = states + .get(&ret_location) + .cloned() + .unwrap_or_default() + .values() + .flat_map(|st| st.writes.as_set()) + .map(|p| p.base() - 1) + .collect(); + + wbret.insert( + unsafe { std::mem::transmute(ret_loc_assign0.data()) }, + writes, + ); + } else { + let preds = body.basic_blocks.predecessors().get(ret_location.block); + + if let Some(v) = preds { + for i in v { + if let Some(sp) = exists_assign0(body, *i) { + let loc = Location { + block: *i, + statement_index: body.basic_blocks[*i].statements.len(), + }; + + let writes: BTreeSet<_> = states + .get(&loc) + .cloned() + .unwrap_or_default() + .values() + .flat_map(|st| st.writes.as_set()) + .map(|p| p.base() - 1) + .collect(); + + wbret.insert(unsafe { std::mem::transmute(sp.data()) }, writes); + } + } + } + } + } + + let wbret = if let Some(old) = wbrets.get(def_id) { + let spans: BTreeSet<_> = wbret.keys().chain(old.keys()).cloned().collect(); + + spans + .into_iter() + .map(|sp| { + ( + sp, + match (wbret.get(&sp), old.get(&sp)) { + (Some(v1), Some(v2)) => { + v1.intersection(v2).cloned().collect::>() + } + (Some(v), None) | (None, Some(v)) => (*v).clone(), + _ => unreachable!(), + }, + ) + }) + .collect::>() + } else { + wbret + }; + wbrets.insert(*def_id, wbret); + + let mut return_states = ret_location .and_then(|ret| states.get(&ret)) .cloned() .unwrap_or_default(); @@ -262,7 +345,14 @@ pub fn analyze( .into_iter() .map(|(def_id, summary)| { let output_params = output_params_map.remove(&def_id).unwrap(); - (def_id, (summary, output_params)) + ( + def_id, + ( + summary, + output_params, + wbrets.get(&def_id).cloned().unwrap_or_default(), + ), + ) }) .collect() } @@ -1062,6 +1152,17 @@ fn return_location(body: &Body<'_>) -> Option { None } +fn exists_assign0(body: &Body<'_>, bb: BasicBlock) -> Option { + for stmt in body.basic_blocks[bb].statements.iter() { + if let StatementKind::Assign(rb) = &stmt.kind { + if (**rb).0.local.as_u32() == 0u32 { + return Some(stmt.source_info.span); + } + } + } + None +} + fn get_rpo_map(body: &Body<'_>) -> BTreeMap { body.basic_blocks .reverse_postorder() diff --git a/src/bin/nopcrat.rs b/src/bin/nopcrat.rs index 468ae5a..df7a814 100644 --- a/src/bin/nopcrat.rs +++ b/src/bin/nopcrat.rs @@ -98,7 +98,7 @@ fn main() { }; if args.verbose { - for (func, params) in &analysis_result { + for (func, params) in &analysis_result.0 { println!("{}", func); for param in params { println!(" {:?}", param); @@ -107,7 +107,7 @@ fn main() { } if args.sample_negative { - let mut fns = sampling::sample_from_path(path, &analysis_result); + let mut fns = sampling::sample_from_path(path, &analysis_result.0); fns.shuffle(&mut thread_rng()); for f in fns.iter().take(10) { println!("{:?}", f); @@ -116,6 +116,7 @@ fn main() { } if args.sample_may || args.sample_must { let mut params: Vec<_> = analysis_result + .0 .iter() .filter(|(_, params)| params.iter().any(|p| p.must == args.sample_must)) .collect(); @@ -127,12 +128,14 @@ fn main() { } if args.use_analysis_result.is_none() { - let fns = analysis_result.len(); + let fns = analysis_result.0.len(); let musts = analysis_result + .0 .values() .map(|v| v.iter().filter(|p| p.must).count()) .sum::(); let mays = analysis_result + .0 .values() .map(|v| v.iter().filter(|p| !p.must).count()) .sum::(); @@ -148,7 +151,7 @@ fn main() { return; } - transform::transform_path(path, &analysis_result); + transform::transform_path(path, &analysis_result.0, &analysis_result.1); } fn clear_dir(path: &Path) { diff --git a/src/compile_util.rs b/src/compile_util.rs index 4560401..1178f05 100644 --- a/src/compile_util.rs +++ b/src/compile_util.rs @@ -99,6 +99,13 @@ pub fn span_to_path(span: Span, source_map: &SourceMap) -> Option { pub fn apply_suggestions>(suggestions: &BTreeMap>) { for (path, suggestions) in suggestions { let code = String::from_utf8(fs::read(path).unwrap()).unwrap(); + // for suggestion in suggestions { + // println!("{:?}", path.as_ref()); + // println!("{:?}", suggestion.snippets[0].line_range); + // println!("{:?}", suggestion.snippets[0].range); + // println!("{}", suggestion.solutions[0].replacements[0].replacement); + // println!(); + // } let fixed = rustfix::apply_suggestions(&code, suggestions).unwrap(); fs::write(path, fixed.as_bytes()).unwrap(); } diff --git a/src/sampling.rs b/src/sampling.rs index e0a6a41..ec882f4 100644 --- a/src/sampling.rs +++ b/src/sampling.rs @@ -8,17 +8,17 @@ use rustc_middle::{ }; use rustc_session::config::Input; -use crate::{ai::analysis::AnalysisResult, compile_util}; +use crate::{ai::analysis::OutputParams, compile_util}; -pub fn sample_from_path(path: &Path, res: &AnalysisResult) -> Vec { +pub fn sample_from_path(path: &Path, res: &OutputParams) -> Vec { sample_from_input(compile_util::path_to_input(path), res) } -pub fn sample_from_code(code: &str, res: &AnalysisResult) -> Vec { +pub fn sample_from_code(code: &str, res: &OutputParams) -> Vec { sample_from_input(compile_util::str_to_input(code), res) } -fn sample_from_input(input: Input, res: &AnalysisResult) -> Vec { +fn sample_from_input(input: Input, res: &OutputParams) -> Vec { let config = compile_util::make_config(input); compile_util::run_compiler(config, |tcx| { let hir = tcx.hir(); diff --git a/src/transform.rs b/src/transform.rs index d9394e4..fc57681 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -14,25 +14,48 @@ use rustc_middle::{ mir::{BasicBlock, TerminatorKind}, ty::TyCtxt, }; -use rustc_span::{def_id::DefId, source_map::SourceMap, BytePos, Span}; +use rustc_span::{def_id::DefId, source_map::SourceMap, BytePos, Span, SpanData}; use rustfix::Suggestion; use crate::{ai::analysis::*, compile_util}; -pub fn transform_path(path: &Path, params: &BTreeMap>) { +pub fn transform_path( + path: &Path, + params: &OutputParams, + writes: &BTreeMap, +) { let input = compile_util::path_to_input(path); let config = compile_util::make_config(input); - let suggestions = compile_util::run_compiler(config, |tcx| transform(tcx, params)).unwrap(); + let suggestions = + compile_util::run_compiler(config, |tcx| transform(tcx, params, writes)).unwrap(); compile_util::apply_suggestions(&suggestions); } fn transform( tcx: TyCtxt<'_>, - param_map: &BTreeMap>, + param_map: &OutputParams, + writes: &BTreeMap, ) -> BTreeMap> { let hir = tcx.hir(); let source_map = tcx.sess.source_map(); + let writes = writes + .iter() + .map(|(k, v)| { + ( + k, + v.iter() + .map(|(k, v)| { + ( + unsafe { std::mem::transmute::(*k) }.span(), + v, + ) + }) + .collect::>(), + ) + }) + .collect::>(); + let mut def_id_ty_map = BTreeMap::new(); for id in hir.items() { let item = hir.item(id); @@ -48,6 +71,7 @@ fn transform( } let mut funcs = BTreeMap::new(); + let mut wbrets = BTreeMap::new(); for id in hir.items() { let item = hir.item(id); let ItemKind::Fn(sig, _, body_id) = item.kind else { @@ -135,6 +159,24 @@ fn transform( (*index, param) }) .collect(); + if let Some(write) = writes.get(&name) { + wbrets.insert( + def_id, + write + .clone() + .into_iter() + .map(|(sp, params)| { + ( + sp, + params + .iter() + .filter_map(|i| index_map.get(i).map(|p| p.name.clone())) + .collect::>(), + ) + }) + .collect::>(), + ); + } let hir_id_map: BTreeMap<_, _> = index_map .values() .cloned() @@ -338,7 +380,11 @@ fn transform( let post_span = post_span.with_hi(post_span.hi() + BytePos(1)); let rv = format!("{}rv___{}", pre_s, post_s); - let rv = func.return_value(Some(rv)); + let (_, wbret) = wbrets[&def_id] + .iter() + .find(|(sp, _)| span.contains(**sp)) + .unwrap(); + let rv = func.return_value(Some(rv), wbret); fix( post_span, format!("; return {};{}", rv, if arm { " }" } else { "" }), @@ -397,16 +443,16 @@ fn transform( } else if passes.contains(¶m.name) { format!( " - let mut {0}___s: bool = false; \ - let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]); \ - let mut {0}: *mut {1} = &mut {0}___v;", + let mut {0}___s: bool = false; \ + let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]); \ + let mut {0}: *mut {1} = &mut {0}___v;", param.name, param.ty, ) } else { format!( " - let mut {0}___s: bool = false; \ - let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]);", + let mut {0}___s: bool = false; \ + let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]);", param.name, param.ty, ) } @@ -463,7 +509,7 @@ fn transform( let post_s = source_map.span_to_snippet(*s).unwrap(); rv = format!("{}{}___v{}", rv, sorted_ss[i + 1].0, post_s); } - let rv = func.return_value(Some(rv)); + let rv = func.return_value(Some(rv), &[]); fix(*span, format!("return {}", rv)); } @@ -474,14 +520,18 @@ fn transform( continue; } let orig = value.map(|value| source_map.span_to_snippet(value).unwrap()); - let ret_v = func.return_value(orig); + let (_, wbret) = wbrets[&def_id] + .iter() + .find(|(sp, _)| span.contains(**sp)) + .unwrap(); + let ret_v = func.return_value(orig, wbret); fix(span, format!("return {}", ret_v)); } if func.is_unit { let pos = body.value.span.hi() - BytePos(1); let span = body.value.span.with_lo(pos).with_hi(pos); - let ret_v = func.return_value(None); + let ret_v = func.return_value(None, &[]); fix(span, ret_v); } } @@ -655,15 +705,19 @@ impl Func { } } - fn return_value(&self, orig: Option) -> String { + fn return_value(&self, orig: Option, wbret: &[String]) -> String { let mut values = vec![]; if let Some((_, i)) = &self.first_return { let orig = orig.unwrap(); let param = &self.index_map[i]; - let v = format!( - "if {0}___s {{ Ok({0}___v) }} else {{ Err({1}) }}", - param.name, orig - ); + let v = if wbret.contains(¶m.name) { + format!("Ok({}___v)", param.name) + } else { + format!( + "if {0}___s {{ Ok({0}___v) }} else {{ Err({1}) }}", + param.name, orig + ) + }; values.push(v); } else if let Some(v) = orig { values.push(v); @@ -672,6 +726,8 @@ impl Func { let param = &self.index_map[i]; let v = if param.must { format!("{}___v", param.name) + } else if wbret.contains(¶m.name) { + format!("Some ({}___v)", param.name) } else { format!("if {0}___s {{ Some({0}___v) }} else {{ None }}", param.name) }; From ff39e653bfdf20b39084cf1dfa5d9ad8df20826d Mon Sep 17 00:00:00 2001 From: HoseongLee Date: Tue, 20 Aug 2024 19:29:46 +0900 Subject: [PATCH 02/14] remove ___s assign if possible --- src/ai/analysis.rs | 142 ++++++++++++++++++++++++++++++++++++++------- src/transform.rs | 131 +++++++++++++++++++++++++---------------- 2 files changed, 203 insertions(+), 70 deletions(-) diff --git a/src/ai/analysis.rs b/src/ai/analysis.rs index 44639fc..cef3f44 100644 --- a/src/ai/analysis.rs +++ b/src/ai/analysis.rs @@ -48,9 +48,14 @@ impl Default for AnalysisConfig { } pub type OutputParams = BTreeMap>; + +// Write Before RETurn s pub type Wbrets = BTreeMap>; -pub type AnalysisResult = (OutputParams, BTreeMap); +// Removable checks for Write s +pub type Rcfws = BTreeMap>; + +pub type AnalysisResult = (OutputParams, BTreeMap); pub fn analyze_path(path: &Path, conf: &AnalysisConfig) -> AnalysisResult { analyze_input(compile_util::path_to_input(path), conf) @@ -65,18 +70,18 @@ pub fn analyze_input(input: Input, conf: &AnalysisConfig) -> AnalysisResult { compile_util::run_compiler(config, |tcx| { analyze(tcx, conf) .into_iter() - .filter_map(|(def_id, (_, params, writes))| { + .filter_map(|(def_id, (_, params, wbret, rcfw))| { if params.is_empty() { None } else { - Some((tcx.def_path_str(def_id), (params, writes))) + Some((tcx.def_path_str(def_id), (params, wbret, rcfw))) } }) .collect::>() }) .unwrap() .into_iter() - .map(|(k, (v1, v2))| ((k.clone(), v1), (k, v2))) + .map(|(k, (v1, v2, v3))| ((k.clone(), v1), (k.clone(), (v2, v3)))) .unzip() } @@ -90,14 +95,7 @@ enum Write { pub fn analyze( tcx: TyCtxt<'_>, conf: &AnalysisConfig, -) -> BTreeMap< - DefId, - ( - FunctionSummary, - Vec, - Wbrets, - ), -> { +) -> BTreeMap, Wbrets, Rcfws)> { let hir = tcx.hir(); let mut call_graph = BTreeMap::new(); @@ -164,6 +162,9 @@ pub fn analyze( let mut call_args_map = BTreeMap::new(); let mut analysis_times: BTreeMap<_, u128> = BTreeMap::new(); let mut wbrets: BTreeMap>> = BTreeMap::new(); + let mut wbbbrets: BTreeMap>> = BTreeMap::new(); + + let mut rcfws = BTreeMap::new(); for id in &po { let def_ids = &elems[id]; let recursive = if def_ids.len() == 1 { @@ -216,11 +217,18 @@ pub fn analyze( let ret_location = return_location(body); let mut wbret = BTreeMap::new(); + let mut wbbbret = BTreeMap::new(); if let Some(ret_location) = ret_location { - if let Some(ret_loc_assign0) = exists_assign0(body, ret_location.block) { + if let Some((ret_loc_assign0, index)) = exists_assign0(body, ret_location.block) + { + let loc = Location { + block: ret_location.block, + statement_index: index, + }; + let writes: BTreeSet<_> = states - .get(&ret_location) + .get(&loc) .cloned() .unwrap_or_default() .values() @@ -230,17 +238,18 @@ pub fn analyze( wbret.insert( unsafe { std::mem::transmute(ret_loc_assign0.data()) }, - writes, + writes.clone(), ); + wbbbret.insert(ret_location.block, writes); } else { let preds = body.basic_blocks.predecessors().get(ret_location.block); if let Some(v) = preds { for i in v { - if let Some(sp) = exists_assign0(body, *i) { + if let Some((sp, index)) = exists_assign0(body, *i) { let loc = Location { block: *i, - statement_index: body.basic_blocks[*i].statements.len(), + statement_index: index, }; let writes: BTreeSet<_> = states @@ -252,7 +261,11 @@ pub fn analyze( .map(|p| p.base() - 1) .collect(); - wbret.insert(unsafe { std::mem::transmute(sp.data()) }, writes); + wbret.insert( + unsafe { std::mem::transmute(sp.data()) }, + writes.clone(), + ); + wbbbret.insert(*i, writes); } } } @@ -280,7 +293,30 @@ pub fn analyze( } else { wbret }; + + let wbbbret = if let Some(old) = wbbbrets.get(def_id) { + let keys: BTreeSet<_> = wbbbret.keys().chain(old.keys()).cloned().collect(); + + keys.into_iter() + .map(|bb| { + ( + bb, + match (wbbbret.get(&bb), old.get(&bb)) { + (Some(v1), Some(v2)) => { + v1.intersection(v2).cloned().collect::>() + } + (Some(v), None) | (None, Some(v)) => (*v).clone(), + _ => unreachable!(), + }, + ) + }) + .collect::>() + } else { + wbbbret + }; + wbrets.insert(*def_id, wbret); + wbbbrets.insert(*def_id, wbbbret); let mut return_states = ret_location .and_then(|ret| states.get(&ret)) @@ -323,6 +359,69 @@ pub fn analyze( for p in &mut output_params { analyzer.find_complete_write(p, &result, &writes_map, &call_args, *def_id); } + + let body = tcx.optimized_mir(*def_id); + let wbbbret = &wbbbrets[def_id]; + let mut rcfw: Rcfws = BTreeMap::new(); + for p in output_params.iter() { + let OutputParam { + index, + must: _, + return_values: _, + complete_writes, + } = p; + for complete_write in complete_writes.iter() { + let CompleteWrite { + block, + statement_index, + write_arg: _, + } = complete_write; + + let mut stack = vec![BasicBlock::from_usize(*block)]; + + let success = loop { + if let Some(block) = stack.pop() { + match wbbbret.get(&block) { + Some(ws) => { + if !ws.contains(index) { + break false; + } + } + None => (), + } + + let bbd = &body.basic_blocks[block]; + let term = bbd.terminator(); + + match term.kind { + TerminatorKind::Return => (), + _ => { + for bb in term.successors() { + stack.push(bb); + } + } + } + } else { + break true; + } + }; + + if success { + let location = Location { + block: BasicBlock::from_usize(*block), + statement_index: *statement_index, + }; + let span = unsafe { + std::mem::transmute(body.source_info(location).span.data()) + }; + + let entry = rcfw.entry(*index); + entry.or_default().insert(span); + } + } + } + + rcfws.insert(*def_id, rcfw); output_params_map.insert(*def_id, output_params); } break; @@ -351,6 +450,7 @@ pub fn analyze( summary, output_params, wbrets.get(&def_id).cloned().unwrap_or_default(), + rcfws.get(&def_id).cloned().unwrap_or_default(), ), ) }) @@ -1152,11 +1252,11 @@ fn return_location(body: &Body<'_>) -> Option { None } -fn exists_assign0(body: &Body<'_>, bb: BasicBlock) -> Option { - for stmt in body.basic_blocks[bb].statements.iter() { +fn exists_assign0(body: &Body<'_>, bb: BasicBlock) -> Option<(Span, usize)> { + for (i, stmt) in body.basic_blocks[bb].statements.iter().enumerate() { if let StatementKind::Assign(rb) = &stmt.kind { if (**rb).0.local.as_u32() == 0u32 { - return Some(stmt.source_info.span); + return Some((stmt.source_info.span, i)); } } } diff --git a/src/transform.rs b/src/transform.rs index fc57681..fc1b36f 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -22,40 +22,23 @@ use crate::{ai::analysis::*, compile_util}; pub fn transform_path( path: &Path, params: &OutputParams, - writes: &BTreeMap, + extra_info: &BTreeMap, ) { let input = compile_util::path_to_input(path); let config = compile_util::make_config(input); let suggestions = - compile_util::run_compiler(config, |tcx| transform(tcx, params, writes)).unwrap(); + compile_util::run_compiler(config, |tcx| transform(tcx, params, extra_info)).unwrap(); compile_util::apply_suggestions(&suggestions); } fn transform( tcx: TyCtxt<'_>, param_map: &OutputParams, - writes: &BTreeMap, + extra_info: &BTreeMap, ) -> BTreeMap> { let hir = tcx.hir(); let source_map = tcx.sess.source_map(); - let writes = writes - .iter() - .map(|(k, v)| { - ( - k, - v.iter() - .map(|(k, v)| { - ( - unsafe { std::mem::transmute::(*k) }.span(), - v, - ) - }) - .collect::>(), - ) - }) - .collect::>(); - let mut def_id_ty_map = BTreeMap::new(); for id in hir.items() { let item = hir.item(id); @@ -72,6 +55,7 @@ fn transform( let mut funcs = BTreeMap::new(); let mut wbrets = BTreeMap::new(); + let mut rcfws = BTreeMap::new(); for id in hir.items() { let item = hir.item(id); let ItemKind::Fn(sig, _, body_id) = item.kind else { @@ -159,24 +143,46 @@ fn transform( (*index, param) }) .collect(); - if let Some(write) = writes.get(&name) { + + if let Some((wbret, _)) = extra_info.get(&name) { wbrets.insert( def_id, - write + wbret .clone() .into_iter() .map(|(sp, params)| { ( - sp, + unsafe { std::mem::transmute::(sp) }.span(), params .iter() .filter_map(|i| index_map.get(i).map(|p| p.name.clone())) .collect::>(), ) }) - .collect::>(), + .collect::>(), + ); + } + + if let Some((_, rcfw)) = extra_info.get(&name) { + rcfws.insert( + def_id, + rcfw.clone() + .into_iter() + .map(|(index, spans)| { + ( + index_map.get(&index).cloned().unwrap().name, + spans + .iter() + .map(|sp| { + unsafe { std::mem::transmute::(*sp) }.span() + }) + .collect(), + ) + }) + .collect::>>(), ); } + let hir_id_map: BTreeMap<_, _> = index_map .values() .cloned() @@ -298,7 +304,11 @@ fn transform( } let assign_map = curr.map(|c| c.assign_map(span)).unwrap_or_default(); - let mut mtch = func.call_match(&args, &assign_map); + + let mut mtch = func.first_return.and_then(|(_, first)| { + let set_flag = generate_set_flag(&span, &first, &rcfws[&def_id], &assign_map); + func.call_match(&args, set_flag) + }); if let Some(call) = get_if_cmp_call(hir_id, span, tcx) { if let Some(then) = func.cmp(call.op, call.target) { @@ -310,11 +320,7 @@ fn transform( let fail = "Err(_) => "; let (_, i) = func.first_return.as_ref().unwrap(); let arg = &args[*i]; - let set_flag = if let Some(arg) = assign_map.get(i) { - format!("{}___s = true;", arg) - } else { - "".to_string() - }; + let set_flag = generate_set_flag(&span, i, &rcfws[&def_id], &assign_map); let assign = if arg.code.contains("&mut ") { format!(" *({}) = v___; {}", arg.code, set_flag) } else { @@ -382,7 +388,7 @@ fn transform( let rv = format!("{}rv___{}", pre_s, post_s); let (_, wbret) = wbrets[&def_id] .iter() - .find(|(sp, _)| span.contains(**sp)) + .find(|(sp, _)| span.contains(*sp)) .unwrap(); let rv = func.return_value(Some(rv), wbret); fix( @@ -398,7 +404,18 @@ fn transform( } fix(span.shrink_to_lo(), binding); - let mut assign = func.call_assign(&args, &assign_map); + let set_flags = func + .remaining_return + .iter() + .map(|i| { + ( + *i, + generate_set_flag(&span, i, &rcfws[&def_id], &assign_map), + ) + }) + .collect(); + + let mut assign = func.call_assign(&args, &set_flags); if let Some(m) = &mtch { assign += m; assign += ")"; @@ -464,10 +481,18 @@ fn transform( fix(span, local_vars); for param in func.params() { + let rcfw = &rcfws[&def_id].get(¶m.name); + for span in ¶m.writes { if call_spans.contains(span) { continue; } + if let Some(rcfw) = rcfw { + if rcfw.iter().any(|sp| span.contains(*sp)) { + continue; + } + } + let pos = span.hi() + BytePos(1); let span = span.with_hi(pos).with_lo(pos); let assign = format!("{0}___s = true;", param.name); @@ -522,7 +547,7 @@ fn transform( let orig = value.map(|value| source_map.span_to_snippet(value).unwrap()); let (_, wbret) = wbrets[&def_id] .iter() - .find(|(sp, _)| span.contains(**sp)) + .find(|(sp, _)| span.contains(*sp)) .unwrap(); let ret_v = func.return_value(orig, wbret); fix(span, format!("return {}", ret_v)); @@ -616,34 +641,29 @@ impl Func { map } - fn call_assign(&self, args: &[Arg], assign_map: &BTreeMap) -> String { + fn call_assign(&self, args: &[Arg], set_flags: &BTreeMap) -> String { let mut assigns = vec![]; for i in &self.remaining_return { let arg = &args[*i]; let param = &self.index_map[i]; - let set_flag = if let Some(arg) = assign_map.get(i) { - format!("{}___s = true;", arg) - } else { - "".to_string() - }; let assign = if param.must { if arg.code.contains("&mut ") { - format!("*({}) = rv___{}; {}", arg.code, i, set_flag) + format!("*({}) = rv___{}; {}", arg.code, i, set_flags[i]) } else { format!( "if !({0}).is_null() {{ *({0}) = rv___{1}; {2} }}", - arg.code, i, set_flag + arg.code, i, set_flags[i] ) } } else if arg.code.contains("&mut ") { format!( "if let Some(v___) = rv___{} {{ *({}) = v___; {} }}", - i, arg.code, set_flag + i, arg.code, set_flags[i] ) } else { format!( "if !({0}).is_null() {{ if let Some(v___) = rv___{1} {{ *({0}) = v___; {2} }} }}", - arg.code, i, set_flag + arg.code, i, set_flags[i] ) }; assigns.push(assign); @@ -652,14 +672,9 @@ impl Func { mk_string(assigns.iter(), "; ", " ", end) } - fn call_match(&self, args: &[Arg], assign_map: &BTreeMap) -> Option { + fn call_match(&self, args: &[Arg], set_flag: String) -> Option { let (succ_value, first) = &self.first_return?; let arg = &args[*first]; - let set_flag = if let Some(arg) = assign_map.get(first) { - format!("{}___s = true;", arg) - } else { - "".to_string() - }; let assign = if arg.code.contains("&mut ") { format!("*({}) = v___; {}", arg.code, set_flag) } else { @@ -1070,3 +1085,21 @@ fn mk_string, I: Iterator>( s.push_str(end); s } + +fn generate_set_flag( + span: &Span, + i: &usize, + rcfws: &BTreeMap>, + assign_map: &BTreeMap, +) -> String { + if let Some(arg) = assign_map.get(i) { + let rcfw = &rcfws.get(arg); + if let Some(rcfw) = rcfw { + if rcfw.iter().any(|sp| span.contains(*sp)) { + return "".to_string(); + } + } + return format!("{}___s = true;", arg); + } + "".to_string() +} From 36284b31783fb062aff2fd188924c615c121e5d9 Mon Sep 17 00:00:00 2001 From: HoseongLee Date: Thu, 22 Aug 2024 14:08:27 +0900 Subject: [PATCH 03/14] fixed bug in previous commit and added more cases for ___s removal --- src/ai/analysis.rs | 148 +++++++++++++++++++-------------------------- src/transform.rs | 89 ++++++++++++++++----------- 2 files changed, 115 insertions(+), 122 deletions(-) diff --git a/src/ai/analysis.rs b/src/ai/analysis.rs index cef3f44..85b9aee 100644 --- a/src/ai/analysis.rs +++ b/src/ai/analysis.rs @@ -18,7 +18,7 @@ use rustc_middle::{ ty::{AdtKind, Ty, TyCtxt, TyKind, TypeAndMut}, }; use rustc_session::config::Input; -use rustc_span::{def_id::DefId, source_map::SourceMap, Span}; +use rustc_span::{def_id::DefId, source_map::SourceMap, BytePos, Span}; use serde::{Deserialize, Serialize}; use super::{domains::*, semantics::TransferedTerminator}; @@ -50,7 +50,7 @@ impl Default for AnalysisConfig { pub type OutputParams = BTreeMap>; // Write Before RETurn s -pub type Wbrets = BTreeMap>; +pub type Wbrets = BTreeMap, BTreeSet)>; // Removable checks for Write s pub type Rcfws = BTreeMap>; @@ -161,7 +161,8 @@ pub fn analyze( let mut wm_map = BTreeMap::new(); let mut call_args_map = BTreeMap::new(); let mut analysis_times: BTreeMap<_, u128> = BTreeMap::new(); - let mut wbrets: BTreeMap>> = BTreeMap::new(); + + let mut wbrets: BTreeMap = BTreeMap::new(); let mut wbbbrets: BTreeMap>> = BTreeMap::new(); let mut rcfws = BTreeMap::new(); @@ -220,100 +221,76 @@ pub fn analyze( let mut wbbbret = BTreeMap::new(); if let Some(ret_location) = ret_location { + let mut stack = vec![]; + if let Some((ret_loc_assign0, index)) = exists_assign0(body, ret_location.block) { - let loc = Location { - block: ret_location.block, - statement_index: index, - }; - - let writes: BTreeSet<_> = states - .get(&loc) - .cloned() - .unwrap_or_default() - .values() - .flat_map(|st| st.writes.as_set()) - .map(|p| p.base() - 1) - .collect(); - - wbret.insert( - unsafe { std::mem::transmute(ret_loc_assign0.data()) }, - writes.clone(), - ); - wbbbret.insert(ret_location.block, writes); + stack.push(( + Location { + block: ret_location.block, + statement_index: index, + }, + ret_loc_assign0.data(), + )); } else { let preds = body.basic_blocks.predecessors().get(ret_location.block); - if let Some(v) = preds { for i in v { if let Some((sp, index)) = exists_assign0(body, *i) { - let loc = Location { - block: *i, - statement_index: index, - }; - - let writes: BTreeSet<_> = states - .get(&loc) - .cloned() - .unwrap_or_default() - .values() - .flat_map(|st| st.writes.as_set()) - .map(|p| p.base() - 1) - .collect(); - - wbret.insert( - unsafe { std::mem::transmute(sp.data()) }, - writes.clone(), - ); - wbbbret.insert(*i, writes); + stack.push(( + Location { + block: *i, + statement_index: index, + }, + sp.data(), + )); } } } } - } - let wbret = if let Some(old) = wbrets.get(def_id) { - let spans: BTreeSet<_> = wbret.keys().chain(old.keys()).cloned().collect(); + if stack.is_empty() { + let span = body.source_info(ret_location).span; + let pos = span.lo() - BytePos(1); + stack.push((ret_location, span.with_lo(pos).with_hi(pos).data())); + } - spans - .into_iter() - .map(|sp| { - ( - sp, - match (wbret.get(&sp), old.get(&sp)) { - (Some(v1), Some(v2)) => { - v1.intersection(v2).cloned().collect::>() + for (loc, sp) in stack.iter() { + let must_writes: BTreeSet<_> = states + .get(loc) + .cloned() + .unwrap_or_default() + .values() + .fold(None, |acc: Option>, st: &AbsState| { + Some(match acc { + Some(acc) => { + acc.intersection(st.writes.as_set()).cloned().collect() } - (Some(v), None) | (None, Some(v)) => (*v).clone(), - _ => unreachable!(), - }, - ) - }) - .collect::>() - } else { - wbret - }; + None => st.writes.as_set().clone(), + }) + }) + .unwrap_or_default() + .iter() + .map(|p| p.base() - 1) + .collect(); - let wbbbret = if let Some(old) = wbbbrets.get(def_id) { - let keys: BTreeSet<_> = wbbbret.keys().chain(old.keys()).cloned().collect(); + let may_writes: BTreeSet<_> = states + .get(loc) + .cloned() + .unwrap_or_default() + .values() + .flat_map(|st| st.writes.as_set()) + .map(|p| p.base() - 1) + .collect(); - keys.into_iter() - .map(|bb| { - ( - bb, - match (wbbbret.get(&bb), old.get(&bb)) { - (Some(v1), Some(v2)) => { - v1.intersection(v2).cloned().collect::>() - } - (Some(v), None) | (None, Some(v)) => (*v).clone(), - _ => unreachable!(), - }, - ) - }) - .collect::>() - } else { - wbbbret - }; + wbret.insert( + unsafe { std::mem::transmute(*sp) }, + (may_writes, must_writes.clone()), + ); + + wbbbret.insert(loc.block, must_writes); + } + } wbrets.insert(*def_id, wbret); wbbbrets.insert(*def_id, wbbbret); @@ -381,13 +358,10 @@ pub fn analyze( let success = loop { if let Some(block) = stack.pop() { - match wbbbret.get(&block) { - Some(ws) => { - if !ws.contains(index) { - break false; - } + if let Some(ws) = wbbbret.get(&block) { + if !ws.contains(index) { + break false; } - None => (), } let bbd = &body.basic_blocks[block]; diff --git a/src/transform.rs b/src/transform.rs index fc1b36f..81fbeaa 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -150,13 +150,17 @@ fn transform( wbret .clone() .into_iter() - .map(|(sp, params)| { + .map(|(sp, (may, must))| { ( unsafe { std::mem::transmute::(sp) }.span(), - params - .iter() - .filter_map(|i| index_map.get(i).map(|p| p.name.clone())) - .collect::>(), + ( + may.iter() + .filter_map(|i| index_map.get(i).map(|p| p.name.clone())) + .collect::>(), + must.iter() + .filter_map(|i| index_map.get(i).map(|p| p.name.clone())) + .collect::>(), + ), ) }) .collect::>(), @@ -386,11 +390,11 @@ fn transform( let post_span = post_span.with_hi(post_span.hi() + BytePos(1)); let rv = format!("{}rv___{}", pre_s, post_s); - let (_, wbret) = wbrets[&def_id] + let (_, (may, must)) = wbrets[&def_id] .iter() .find(|(sp, _)| span.contains(*sp)) .unwrap(); - let rv = func.return_value(Some(rv), wbret); + let rv = func.return_value(Some(rv), may, must); fix( post_span, format!("; return {};{}", rv, if arm { " }" } else { "" }), @@ -439,10 +443,33 @@ fn transform( let ret_ty = func.return_type(orig); fix(span, format!("-> {}", ret_ty)); + let mut unremovable = BTreeSet::new(); + for param in func.params() { + let rcfw = &rcfws[&def_id].get(¶m.name); + + for span in ¶m.writes { + if call_spans.contains(span) { + continue; + } + if let Some(rcfw) = rcfw { + if rcfw.iter().any(|sp| span.contains(*sp)) { + continue; + } + } + + unremovable.insert(¶m.name); + + let pos = span.hi() + BytePos(1); + let span = span.with_hi(pos).with_lo(pos); + let assign = format!("{0}___s = true;", param.name); + fix(span, assign); + } + } + let local_vars: String = func .params() .map(|param| { - if param.must { + if param.must || !unremovable.contains(¶m.name) { if passes.contains(¶m.name) { format!( " @@ -480,26 +507,6 @@ fn transform( let span = body.value.span.with_lo(pos).with_hi(pos); fix(span, local_vars); - for param in func.params() { - let rcfw = &rcfws[&def_id].get(¶m.name); - - for span in ¶m.writes { - if call_spans.contains(span) { - continue; - } - if let Some(rcfw) = rcfw { - if rcfw.iter().any(|sp| span.contains(*sp)) { - continue; - } - } - - let pos = span.hi() + BytePos(1); - let span = span.with_hi(pos).with_lo(pos); - let assign = format!("{0}___s = true;", param.name); - fix(span, assign); - } - } - for param in func.params() { if let Some(spans) = ref_to_spans.get(¶m.name) { for span in spans { @@ -534,7 +541,11 @@ fn transform( let post_s = source_map.span_to_snippet(*s).unwrap(); rv = format!("{}{}___v{}", rv, sorted_ss[i + 1].0, post_s); } - let rv = func.return_value(Some(rv), &[]); + let (_, (may, must)) = wbrets[&def_id] + .iter() + .find(|(sp, _)| span.contains(*sp)) + .unwrap(); + let rv = func.return_value(Some(rv), may, must); fix(*span, format!("return {}", rv)); } @@ -545,18 +556,22 @@ fn transform( continue; } let orig = value.map(|value| source_map.span_to_snippet(value).unwrap()); - let (_, wbret) = wbrets[&def_id] + let (_, (may, must)) = wbrets[&def_id] .iter() .find(|(sp, _)| span.contains(*sp)) .unwrap(); - let ret_v = func.return_value(orig, wbret); + let ret_v = func.return_value(orig, may, must); fix(span, format!("return {}", ret_v)); } if func.is_unit { let pos = body.value.span.hi() - BytePos(1); let span = body.value.span.with_lo(pos).with_hi(pos); - let ret_v = func.return_value(None, &[]); + let (_, (may, must)) = wbrets[&def_id] + .iter() + .find(|(sp, _)| span.contains(*sp)) + .unwrap(); + let ret_v = func.return_value(None, may, must); fix(span, ret_v); } } @@ -720,13 +735,15 @@ impl Func { } } - fn return_value(&self, orig: Option, wbret: &[String]) -> String { + fn return_value(&self, orig: Option, may: &[String], must: &[String]) -> String { let mut values = vec![]; if let Some((_, i)) = &self.first_return { let orig = orig.unwrap(); let param = &self.index_map[i]; - let v = if wbret.contains(¶m.name) { + let v = if must.contains(¶m.name) { format!("Ok({}___v)", param.name) + } else if !may.contains(¶m.name) { + format!("Err({})", orig) } else { format!( "if {0}___s {{ Ok({0}___v) }} else {{ Err({1}) }}", @@ -741,8 +758,10 @@ impl Func { let param = &self.index_map[i]; let v = if param.must { format!("{}___v", param.name) - } else if wbret.contains(¶m.name) { + } else if must.contains(¶m.name) { format!("Some ({}___v)", param.name) + } else if !may.contains(¶m.name) { + "None".to_string() } else { format!("if {0}___s {{ Some({0}___v) }} else {{ None }}", param.name) }; From c00f77e307b5b3c7ac3644a0a2f941a662d0c9bf Mon Sep 17 00:00:00 2001 From: HoseongLee Date: Thu, 19 Sep 2024 17:42:08 +0900 Subject: [PATCH 04/14] handle recursive case in rcfw --- src/ai/analysis.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/ai/analysis.rs b/src/ai/analysis.rs index 85b9aee..81beca1 100644 --- a/src/ai/analysis.rs +++ b/src/ai/analysis.rs @@ -355,10 +355,11 @@ pub fn analyze( } = complete_write; let mut stack = vec![BasicBlock::from_usize(*block)]; + let mut visited = BTreeSet::new(); let success = loop { if let Some(block) = stack.pop() { - if let Some(ws) = wbbbret.get(&block) { + if let Some(ws) = wbbbret.get(&block) { if !ws.contains(index) { break false; } @@ -367,11 +368,15 @@ pub fn analyze( let bbd = &body.basic_blocks[block]; let term = bbd.terminator(); + visited.insert(block); + match term.kind { TerminatorKind::Return => (), _ => { for bb in term.successors() { - stack.push(bb); + if !visited.contains(&bb) { + stack.push(bb); + } } } } From 392e46d1541104cb7c7c3e99bfcb8aa0bbc626d7 Mon Sep 17 00:00:00 2001 From: HoseongLee Date: Tue, 24 Sep 2024 12:35:49 +0900 Subject: [PATCH 05/14] removed simplification for unit functions and bug fixes --- src/ai/analysis.rs | 136 ++++++++++++++++++++++----------------------- src/transform.rs | 121 +++++++++++++++++++++++++--------------- 2 files changed, 142 insertions(+), 115 deletions(-) diff --git a/src/ai/analysis.rs b/src/ai/analysis.rs index 81beca1..b95b8e2 100644 --- a/src/ai/analysis.rs +++ b/src/ai/analysis.rs @@ -18,7 +18,7 @@ use rustc_middle::{ ty::{AdtKind, Ty, TyCtxt, TyKind, TypeAndMut}, }; use rustc_session::config::Input; -use rustc_span::{def_id::DefId, source_map::SourceMap, BytePos, Span}; +use rustc_span::{def_id::DefId, source_map::SourceMap, Span}; use serde::{Deserialize, Serialize}; use super::{domains::*, semantics::TransferedTerminator}; @@ -164,6 +164,7 @@ pub fn analyze( let mut wbrets: BTreeMap = BTreeMap::new(); let mut wbbbrets: BTreeMap>> = BTreeMap::new(); + let mut is_units = BTreeMap::new(); let mut rcfws = BTreeMap::new(); for id in &po { @@ -220,9 +221,9 @@ pub fn analyze( let mut wbret = BTreeMap::new(); let mut wbbbret = BTreeMap::new(); - if let Some(ret_location) = ret_location { - let mut stack = vec![]; + let mut stack = vec![]; + if let Some(ret_location) = ret_location { if let Some((ret_loc_assign0, index)) = exists_assign0(body, ret_location.block) { stack.push(( @@ -232,29 +233,21 @@ pub fn analyze( }, ret_loc_assign0.data(), )); - } else { - let preds = body.basic_blocks.predecessors().get(ret_location.block); - if let Some(v) = preds { - for i in v { - if let Some((sp, index)) = exists_assign0(body, *i) { - stack.push(( - Location { - block: *i, - statement_index: index, - }, - sp.data(), - )); - } + } else if let Some(v) = body.basic_blocks.predecessors().get(ret_location.block) + { + for i in v { + if let Some((sp, index)) = exists_assign0(body, *i) { + stack.push(( + Location { + block: *i, + statement_index: index, + }, + sp.data(), + )); } } } - if stack.is_empty() { - let span = body.source_info(ret_location).span; - let pos = span.lo() - BytePos(1); - stack.push((ret_location, span.with_lo(pos).with_hi(pos).data())); - } - for (loc, sp) in stack.iter() { let must_writes: BTreeSet<_> = states .get(loc) @@ -294,6 +287,7 @@ pub fn analyze( wbrets.insert(*def_id, wbret); wbbbrets.insert(*def_id, wbbbret); + is_units.insert(*def_id, stack.is_empty()); let mut return_states = ret_location .and_then(|ret| states.get(&ret)) @@ -340,62 +334,64 @@ pub fn analyze( let body = tcx.optimized_mir(*def_id); let wbbbret = &wbbbrets[def_id]; let mut rcfw: Rcfws = BTreeMap::new(); - for p in output_params.iter() { - let OutputParam { - index, - must: _, - return_values: _, - complete_writes, - } = p; - for complete_write in complete_writes.iter() { - let CompleteWrite { - block, - statement_index, - write_arg: _, - } = complete_write; - - let mut stack = vec![BasicBlock::from_usize(*block)]; - let mut visited = BTreeSet::new(); - - let success = loop { - if let Some(block) = stack.pop() { - if let Some(ws) = wbbbret.get(&block) { - if !ws.contains(index) { - break false; - } - } - let bbd = &body.basic_blocks[block]; - let term = bbd.terminator(); - - visited.insert(block); + if !is_units[def_id] { + for p in output_params.iter() { + let OutputParam { + index, + must: _, + return_values: _, + complete_writes, + } = p; + for complete_write in complete_writes.iter() { + let CompleteWrite { + block, + statement_index, + write_arg: _, + } = complete_write; + + let mut stack = vec![BasicBlock::from_usize(*block)]; + let mut visited = BTreeSet::from_iter(stack.clone()); + + let success = loop { + if let Some(block) = stack.pop() { + if let Some(ws) = wbbbret.get(&block) { + if !ws.contains(index) { + break false; + } + } - match term.kind { - TerminatorKind::Return => (), - _ => { - for bb in term.successors() { - if !visited.contains(&bb) { - stack.push(bb); + let bbd = &body.basic_blocks[block]; + let term = bbd.terminator(); + + match term.kind { + TerminatorKind::Return => (), + _ => { + for bb in term.successors() { + if !visited.contains(&bb) { + visited.insert(bb); + stack.push(bb); + } } } } + } else { + break true; } - } else { - break true; - } - }; - - if success { - let location = Location { - block: BasicBlock::from_usize(*block), - statement_index: *statement_index, - }; - let span = unsafe { - std::mem::transmute(body.source_info(location).span.data()) }; - let entry = rcfw.entry(*index); - entry.or_default().insert(span); + if success { + let location = Location { + block: BasicBlock::from_usize(*block), + statement_index: *statement_index, + }; + let span = unsafe { + std::mem::transmute(body.source_info(location).span.data()) + }; + + let entry = rcfw.entry(*index); + entry.or_default().insert(span); + } } } } diff --git a/src/transform.rs b/src/transform.rs index 81fbeaa..431a28f 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -227,6 +227,9 @@ fn transform( let body = hir.body(body_id); let curr = funcs.get(&def_id); + let default = BTreeMap::new(); + let rcfw = rcfws.get(&def_id).unwrap_or(&default); + let mut visitor = BodyVisitor::new(tcx); visitor.visit_body(body); @@ -310,7 +313,7 @@ fn transform( let assign_map = curr.map(|c| c.assign_map(span)).unwrap_or_default(); let mut mtch = func.first_return.and_then(|(_, first)| { - let set_flag = generate_set_flag(&span, &first, &rcfws[&def_id], &assign_map); + let set_flag = generate_set_flag(&span, &first, rcfw, &assign_map); func.call_match(&args, set_flag) }); @@ -324,7 +327,7 @@ fn transform( let fail = "Err(_) => "; let (_, i) = func.first_return.as_ref().unwrap(); let arg = &args[*i]; - let set_flag = generate_set_flag(&span, i, &rcfws[&def_id], &assign_map); + let set_flag = generate_set_flag(&span, i, rcfw, &assign_map); let assign = if arg.code.contains("&mut ") { format!(" *({}) = v___; {}", arg.code, set_flag) } else { @@ -390,11 +393,13 @@ fn transform( let post_span = post_span.with_hi(post_span.hi() + BytePos(1)); let rv = format!("{}rv___{}", pre_s, post_s); - let (_, (may, must)) = wbrets[&def_id] - .iter() - .find(|(sp, _)| span.contains(*sp)) - .unwrap(); - let rv = func.return_value(Some(rv), may, must); + let rv = if let Some((_, wbret)) = + wbrets[&def_id].iter().find(|(sp, _)| span.contains(*sp)) + { + func.return_value(Some(rv), Some(wbret)) + } else { + func.return_value(Some(rv), None) + }; fix( post_span, format!("; return {};{}", rv, if arm { " }" } else { "" }), @@ -411,12 +416,7 @@ fn transform( let set_flags = func .remaining_return .iter() - .map(|i| { - ( - *i, - generate_set_flag(&span, i, &rcfws[&def_id], &assign_map), - ) - }) + .map(|i| (*i, generate_set_flag(&span, i, rcfw, &assign_map))) .collect(); let mut assign = func.call_assign(&args, &set_flags); @@ -445,12 +445,9 @@ fn transform( let mut unremovable = BTreeSet::new(); for param in func.params() { - let rcfw = &rcfws[&def_id].get(¶m.name); + let rcfw = &rcfw.get(¶m.name); for span in ¶m.writes { - if call_spans.contains(span) { - continue; - } if let Some(rcfw) = rcfw { if rcfw.iter().any(|sp| span.contains(*sp)) { continue; @@ -459,6 +456,10 @@ fn transform( unremovable.insert(¶m.name); + if call_spans.contains(span) { + continue; + } + let pos = span.hi() + BytePos(1); let span = span.with_hi(pos).with_lo(pos); let assign = format!("{0}___s = true;", param.name); @@ -469,7 +470,7 @@ fn transform( let local_vars: String = func .params() .map(|param| { - if param.must || !unremovable.contains(¶m.name) { + if param.must || (!unremovable.contains(¶m.name)) { if passes.contains(¶m.name) { format!( " @@ -541,38 +542,53 @@ fn transform( let post_s = source_map.span_to_snippet(*s).unwrap(); rv = format!("{}{}___v{}", rv, sorted_ss[i + 1].0, post_s); } - let (_, (may, must)) = wbrets[&def_id] - .iter() - .find(|(sp, _)| span.contains(*sp)) - .unwrap(); - let rv = func.return_value(Some(rv), may, must); + let rv = if let Some((_, wbret)) = + wbrets[&def_id].iter().find(|(sp, _)| span.contains(*sp)) + { + func.return_value(Some(rv), Some(wbret)) + } else { + func.return_value(Some(rv), None) + }; fix(*span, format!("return {}", rv)); } - for ret in visitor.returns { + for ret in visitor.returns.clone() { let Return { span, value } = ret; if ret_call_spans.contains(&span) || ret_to_ref_spans.contains_key(&span) { continue; } let orig = value.map(|value| source_map.span_to_snippet(value).unwrap()); - let (_, (may, must)) = wbrets[&def_id] - .iter() - .find(|(sp, _)| span.contains(*sp)) - .unwrap(); - let ret_v = func.return_value(orig, may, must); + let ret_v = if let Some((_, wbret)) = + wbrets[&def_id].iter().find(|(sp, _)| span.contains(*sp)) + { + func.return_value(orig, Some(wbret)) + } else { + func.return_value(orig, None) + }; fix(span, format!("return {}", ret_v)); } if func.is_unit { let pos = body.value.span.hi() - BytePos(1); let span = body.value.span.with_lo(pos).with_hi(pos); - let (_, (may, must)) = wbrets[&def_id] - .iter() - .find(|(sp, _)| span.contains(*sp)) - .unwrap(); - let ret_v = func.return_value(None, may, must); - fix(span, ret_v); + let pos = pos - BytePos(3); + let prev = span.with_lo(pos).with_hi(pos); + + let mut skip = false; + + for ret in visitor.returns { + let Return { span, value: _ } = ret; + if span.overlaps(prev) { + skip = true; + break; + } + } + + if !skip { + let ret_v = func.return_value(None, None); + fix(span, ret_v); + } } } suggestions.retain(|_, v| !v.is_empty()); @@ -735,15 +751,26 @@ impl Func { } } - fn return_value(&self, orig: Option, may: &[String], must: &[String]) -> String { + fn return_value( + &self, + orig: Option, + wbret: Option<&(Vec, Vec)>, + ) -> String { let mut values = vec![]; if let Some((_, i)) = &self.first_return { let orig = orig.unwrap(); let param = &self.index_map[i]; - let v = if must.contains(¶m.name) { - format!("Ok({}___v)", param.name) - } else if !may.contains(¶m.name) { - format!("Err({})", orig) + let v = if let Some((may, must)) = wbret { + if must.contains(¶m.name) { + format!("Ok({}___v)", param.name) + } else if !may.contains(¶m.name) { + format!("Err({})", orig) + } else { + format!( + "if {0}___s {{ Ok({0}___v) }} else {{ Err({1}) }}", + param.name, orig + ) + } } else { format!( "if {0}___s {{ Ok({0}___v) }} else {{ Err({1}) }}", @@ -758,10 +785,14 @@ impl Func { let param = &self.index_map[i]; let v = if param.must { format!("{}___v", param.name) - } else if must.contains(¶m.name) { - format!("Some ({}___v)", param.name) - } else if !may.contains(¶m.name) { - "None".to_string() + } else if let Some((may, must)) = wbret { + if must.contains(¶m.name) { + format!("Some ({}___v)", param.name) + } else if !may.contains(¶m.name) { + "None".to_string() + } else { + format!("if {0}___s {{ Some({0}___v) }} else {{ None }}", param.name) + } } else { format!("if {0}___s {{ Some({0}___v) }} else {{ None }}", param.name) }; @@ -775,7 +806,7 @@ impl Func { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct Return { span: Span, value: Option, From 913778723ae3476f9ebda3b0d3b5a2b99bf0eb7b Mon Sep 17 00:00:00 2001 From: HoseongLee Date: Wed, 25 Sep 2024 00:23:36 +0900 Subject: [PATCH 06/14] Added terminator to return location check --- src/ai/analysis.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/ai/analysis.rs b/src/ai/analysis.rs index b95b8e2..5007fa5 100644 --- a/src/ai/analysis.rs +++ b/src/ai/analysis.rs @@ -1235,6 +1235,12 @@ fn exists_assign0(body: &Body<'_>, bb: BasicBlock) -> Option<(Span, usize)> { } } } + let term = body.basic_blocks[bb].terminator(); + if let TerminatorKind::Call { func: _, args: _, destination, target: _, unwind: _, call_source: _, fn_span: _ } = term.kind { + if destination.local.as_u32() == 0u32 { + return Some((term.source_info.span, body.basic_blocks[bb].statements.len())); + } + } None } From ddc1706eb9912a9f94d879bac8ab36c8f0b472e5 Mon Sep 17 00:00:00 2001 From: HoseongLee Date: Wed, 25 Sep 2024 03:55:14 +0900 Subject: [PATCH 07/14] fixed bug in terminator return location --- src/ai/analysis.rs | 50 +++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/src/ai/analysis.rs b/src/ai/analysis.rs index 5007fa5..330c1d7 100644 --- a/src/ai/analysis.rs +++ b/src/ai/analysis.rs @@ -224,26 +224,15 @@ pub fn analyze( let mut stack = vec![]; if let Some(ret_location) = ret_location { - if let Some((ret_loc_assign0, index)) = exists_assign0(body, ret_location.block) + if let Some((ret_loc_assign0, ret_loc)) = + exists_assign0(body, ret_location.block) { - stack.push(( - Location { - block: ret_location.block, - statement_index: index, - }, - ret_loc_assign0.data(), - )); + stack.push((ret_loc, ret_loc_assign0.data())); } else if let Some(v) = body.basic_blocks.predecessors().get(ret_location.block) { for i in v { - if let Some((sp, index)) = exists_assign0(body, *i) { - stack.push(( - Location { - block: *i, - statement_index: index, - }, - sp.data(), - )); + if let Some((sp, ret_loc)) = exists_assign0(body, *i) { + stack.push((ret_loc, sp.data())); } } } @@ -1227,18 +1216,39 @@ fn return_location(body: &Body<'_>) -> Option { None } -fn exists_assign0(body: &Body<'_>, bb: BasicBlock) -> Option<(Span, usize)> { +fn exists_assign0(body: &Body<'_>, bb: BasicBlock) -> Option<(Span, Location)> { for (i, stmt) in body.basic_blocks[bb].statements.iter().enumerate() { if let StatementKind::Assign(rb) = &stmt.kind { if (**rb).0.local.as_u32() == 0u32 { - return Some((stmt.source_info.span, i)); + return Some(( + stmt.source_info.span, + Location { + block: bb, + statement_index: i, + }, + )); } } } let term = body.basic_blocks[bb].terminator(); - if let TerminatorKind::Call { func: _, args: _, destination, target: _, unwind: _, call_source: _, fn_span: _ } = term.kind { + if let TerminatorKind::Call { + func: _, + args: _, + destination, + target, + unwind: _, + call_source: _, + fn_span: _, + } = term.kind + { if destination.local.as_u32() == 0u32 { - return Some((term.source_info.span, body.basic_blocks[bb].statements.len())); + return Some(( + term.source_info.span, + Location { + block: target.unwrap(), + statement_index: 0, + }, + )); } } None From 82eae1f3457aca04d5aac1554023df5748865047 Mon Sep 17 00:00:00 2001 From: HoseongLee Date: Fri, 27 Sep 2024 15:19:22 +0900 Subject: [PATCH 08/14] Combined assign and return & disabled simplification for restarted analysis --- src/ai/analysis.rs | 34 ++++-- src/transform.rs | 252 +++++++++++++++++++++++++++++---------------- 2 files changed, 186 insertions(+), 100 deletions(-) diff --git a/src/ai/analysis.rs b/src/ai/analysis.rs index 330c1d7..7ec93a3 100644 --- a/src/ai/analysis.rs +++ b/src/ai/analysis.rs @@ -198,11 +198,14 @@ pub fn analyze( ); } - let AnalyzedBody { - states, - writes_map, - init_state, - } = analyzer.analyze_body(body); + let ( + AnalyzedBody { + states, + writes_map, + init_state, + }, + restarted, + ) = analyzer.analyze_body(body); if conf.print_functions.contains(&tcx.def_path_str(def_id)) { tracing::info!( "{:?}\n{}", @@ -237,6 +240,10 @@ pub fn analyze( } } + if restarted { + stack.clear(); + } + for (loc, sp) in stack.iter() { let must_writes: BTreeSet<_> = states .get(loc) @@ -737,8 +744,9 @@ impl<'a, 'tcx> Analyzer<'a, 'tcx> { } } - fn analyze_body(&mut self, body: &Body<'tcx>) -> AnalyzedBody { + fn analyze_body(&mut self, body: &Body<'tcx>) -> (AnalyzedBody, bool) { let mut start_state = AbsState::bot(); + let mut restarted = false; start_state.writes = MustPathSet::top(); start_state.nulls = MustPathSet::top(); @@ -902,17 +910,21 @@ impl<'a, 'tcx> Analyzer<'a, 'tcx> { } } if restart { + restarted = true; continue 'analysis_loop; } } break (states, writes_map); }; - AnalyzedBody { - states, - writes_map, - init_state, - } + ( + AnalyzedBody { + states, + writes_map, + init_state, + }, + restarted, + ) } pub fn expands_path(&self, place: &AbsPath) -> Vec { diff --git a/src/transform.rs b/src/transform.rs index 431a28f..841d9b5 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -233,10 +233,7 @@ fn transform( let mut visitor = BodyVisitor::new(tcx); visitor.visit_body(body); - let mut pass_visitor = PassVisitor::new(tcx); - pass_visitor.visit_body(body); - - let passes = BTreeSet::from_iter(pass_visitor.passes.iter()); + let passes = BTreeSet::from_iter(visitor.passes.iter()); let mut ret_call_spans = BTreeSet::new(); let mut call_spans = BTreeSet::new(); @@ -393,13 +390,14 @@ fn transform( let post_span = post_span.with_hi(post_span.hi() + BytePos(1)); let rv = format!("{}rv___{}", pre_s, post_s); - let rv = if let Some((_, wbret)) = - wbrets[&def_id].iter().find(|(sp, _)| span.contains(*sp)) - { - func.return_value(Some(rv), Some(wbret)) - } else { - func.return_value(Some(rv), None) - }; + let rv = func.return_value( + Some(rv), + wbrets[&def_id] + .iter() + .find(|(sp, _)| span.contains(*sp)) + .map(|r| &r.1), + None, + ); fix( post_span, format!("; return {};{}", rv, if arm { " }" } else { "" }), @@ -508,6 +506,62 @@ fn transform( let span = body.value.span.with_lo(pos).with_hi(pos); fix(span, local_vars); + for ret in visitor.returns.iter() { + let Return { span, value } = ret; + if ret_call_spans.contains(span) || ret_to_ref_spans.contains_key(span) { + continue; + } + + let orig = value.map(|value| source_map.span_to_snippet(value).unwrap()); + + let mut assign_before_ret = None; + for assign in visitor.assigns.iter() { + let Assign { + name: _, + value: _, + span: sp, + } = assign; + let sp2 = sp.between(*span); + if source_map + .span_to_snippet(sp2) + .unwrap() + .chars() + .filter(|c| !c.is_whitespace()) + .count() + == 0 + { + assign_before_ret = Some(assign); + break; + } + } + + let mut lit_map = None; + + if let Some(assign_before_ret) = assign_before_ret { + let Assign { + name, + value, + span: sp, + } = assign_before_ret; + + if let Some(spans) = ref_to_spans.get_mut(name) { + spans.retain(|span| !sp.contains(*span)); + lit_map = Some((name, value)); + fix(*sp, "".to_string()); + } + } + + let ret_v = func.return_value( + orig, + wbrets[&def_id] + .iter() + .find(|(sp, _)| span.contains(*sp)) + .map(|r| &r.1), + lit_map, + ); + fix(*span, format!("return {}", ret_v)); + } + for param in func.params() { if let Some(spans) = ref_to_spans.get(¶m.name) { for span in spans { @@ -542,33 +596,18 @@ fn transform( let post_s = source_map.span_to_snippet(*s).unwrap(); rv = format!("{}{}___v{}", rv, sorted_ss[i + 1].0, post_s); } - let rv = if let Some((_, wbret)) = - wbrets[&def_id].iter().find(|(sp, _)| span.contains(*sp)) - { - func.return_value(Some(rv), Some(wbret)) - } else { - func.return_value(Some(rv), None) - }; + let rv = func.return_value( + Some(rv), + wbrets[&def_id] + .iter() + .find(|(sp, _)| span.contains(*sp)) + .map(|r| &r.1), + None, + ); fix(*span, format!("return {}", rv)); } - for ret in visitor.returns.clone() { - let Return { span, value } = ret; - if ret_call_spans.contains(&span) || ret_to_ref_spans.contains_key(&span) { - continue; - } - let orig = value.map(|value| source_map.span_to_snippet(value).unwrap()); - let ret_v = if let Some((_, wbret)) = - wbrets[&def_id].iter().find(|(sp, _)| span.contains(*sp)) - { - func.return_value(orig, Some(wbret)) - } else { - func.return_value(orig, None) - }; - fix(span, format!("return {}", ret_v)); - } - if func.is_unit { let pos = body.value.span.hi() - BytePos(1); let span = body.value.span.with_lo(pos).with_hi(pos); @@ -586,8 +625,8 @@ fn transform( } if !skip { - let ret_v = func.return_value(None, None); - fix(span, ret_v); + let ret_v = func.return_value(None, None, None); + fix(span, format!("\t{}\n", ret_v)); } } } @@ -755,26 +794,36 @@ impl Func { &self, orig: Option, wbret: Option<&(Vec, Vec)>, + lit_map: Option<(&String, &String)>, ) -> String { let mut values = vec![]; if let Some((_, i)) = &self.first_return { let orig = orig.unwrap(); let param = &self.index_map[i]; + let name = lit_map + .and_then(|(n, v)| { + if *n == param.name { + Some((*v).clone()) + } else { + None + } + }) + .unwrap_or(format!("{}___v", param.name)); let v = if let Some((may, must)) = wbret { if must.contains(¶m.name) { - format!("Ok({}___v)", param.name) + format!("Ok({})", name) } else if !may.contains(¶m.name) { format!("Err({})", orig) } else { format!( - "if {0}___s {{ Ok({0}___v) }} else {{ Err({1}) }}", - param.name, orig + "if {0}___s {{ Ok({1}) }} else {{ Err({2}) }}", + param.name, name, orig ) } } else { format!( - "if {0}___s {{ Ok({0}___v) }} else {{ Err({1}) }}", - param.name, orig + "if {0}___s {{ Ok({1}) }} else {{ Err({2}) }}", + param.name, name, orig ) }; values.push(v); @@ -783,18 +832,33 @@ impl Func { } for i in &self.remaining_return { let param = &self.index_map[i]; + let name = lit_map + .and_then(|(n, v)| { + if *n == param.name { + Some((*v).clone()) + } else { + None + } + }) + .unwrap_or(format!("{}___v", param.name)); let v = if param.must { - format!("{}___v", param.name) + name } else if let Some((may, must)) = wbret { if must.contains(¶m.name) { - format!("Some ({}___v)", param.name) + format!("Some ({})", name) } else if !may.contains(¶m.name) { "None".to_string() } else { - format!("if {0}___s {{ Some({0}___v) }} else {{ None }}", param.name) + format!( + "if {0}___s {{ Some({1}) }} else {{ None }}", + param.name, name + ) } } else { - format!("if {0}___s {{ Some({0}___v) }} else {{ None }}", param.name) + format!( + "if {0}___s {{ Some({1}) }} else {{ None }}", + param.name, name + ) }; values.push(v); } @@ -833,44 +897,11 @@ struct Ref { name: String, } -struct PassVisitor<'tcx> { - tcx: TyCtxt<'tcx>, - passes: Vec, -} - -impl<'tcx> PassVisitor<'tcx> { - fn new(tcx: TyCtxt<'tcx>) -> Self { - Self { - tcx, - passes: vec![], - } - } -} - -impl<'tcx> HVisitor<'tcx> for PassVisitor<'tcx> { - type NestedFilter = nested_filter::OnlyBodies; - - fn nested_visit_map(&mut self) -> Self::Map { - self.tcx.hir() - } - - fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) { - let source_map = self.tcx.sess.source_map(); - if let ExprKind::Path(p) = expr.kind { - if let Ok(code) = source_map.span_to_snippet(p.qself_span()) { - match get_parent(expr.hir_id, self.tcx) { - Some(e) => { - if let ExprKind::Unary(UnOp::Deref, _) = e.kind { - } else { - self.passes.push(code) - } - } - None => self.passes.push(code), - } - } - } - rustc_hir::intravisit::walk_expr(self, expr); - } +#[derive(Debug)] +struct Assign { + name: String, + value: String, + span: Span, } struct BodyVisitor<'tcx> { @@ -878,6 +909,8 @@ struct BodyVisitor<'tcx> { returns: Vec, calls: Vec, refs: Vec, + passes: Vec, + assigns: Vec, } impl<'tcx> BodyVisitor<'tcx> { @@ -887,6 +920,8 @@ impl<'tcx> BodyVisitor<'tcx> { returns: vec![], calls: vec![], refs: vec![], + passes: vec![], + assigns: vec![], } } } @@ -933,6 +968,48 @@ impl<'tcx> BodyVisitor<'tcx> { }; self.calls.push(call); } + + fn visit_expr_path(&mut self, expr: &'tcx Expr<'tcx>, path: QPath<'tcx>) { + let source_map = self.tcx.sess.source_map(); + if let Ok(code) = source_map.span_to_snippet(path.qself_span()) { + match get_parent(expr.hir_id, self.tcx) { + Some(e) => { + if let ExprKind::Unary(UnOp::Deref, _) = e.kind { + } else { + self.passes.push(code) + } + } + None => self.passes.push(code), + } + } + } + + fn visit_expr_unary(&mut self, expr: &'tcx Expr<'tcx>, e: &'tcx Expr<'tcx>) { + let source_map = self.tcx.sess.source_map(); + self.refs.push(Ref { + hir_id: expr.hir_id, + span: expr.span, + name: source_map.span_to_snippet(e.span).unwrap(), + }) + } + + fn visit_expr_assign( + &mut self, + expr: &'tcx Expr<'tcx>, + lhs: &'tcx Expr<'tcx>, + rhs: &'tcx Expr<'tcx>, + ) { + let source_map = self.tcx.sess.source_map(); + if let ExprKind::Unary(UnOp::Deref, e) = lhs.kind { + if let ExprKind::Lit(_) = rhs.kind { + self.assigns.push(Assign { + name: source_map.span_to_snippet(e.span).unwrap(), + value: source_map.span_to_snippet(rhs.span).unwrap(), + span: expr.span.with_hi(expr.span.hi() + BytePos(1)), + }); + } + } + } } impl<'tcx> HVisitor<'tcx> for BodyVisitor<'tcx> { @@ -943,15 +1020,12 @@ impl<'tcx> HVisitor<'tcx> for BodyVisitor<'tcx> { } fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) { - let source_map = self.tcx.sess.source_map(); match expr.kind { - ExprKind::Unary(UnOp::Deref, e) => self.refs.push(Ref { - hir_id: expr.hir_id, - span: expr.span, - name: source_map.span_to_snippet(e.span).unwrap(), - }), ExprKind::Ret(e) => self.visit_expr_ret(expr, e), ExprKind::Call(callee, args) => self.visit_expr_call(expr, callee, args), + ExprKind::Path(path) => self.visit_expr_path(expr, path), + ExprKind::Unary(UnOp::Deref, e) => self.visit_expr_unary(expr, e), + ExprKind::Assign(lhs, rhs, _) => self.visit_expr_assign(expr, lhs, rhs), _ => {} } rustc_hir::intravisit::walk_expr(self, expr); From 37ec1033768c3edc302f89030858460023824376 Mon Sep 17 00:00:00 2001 From: HoseongLee Date: Fri, 27 Sep 2024 22:04:25 +0900 Subject: [PATCH 09/14] Disabled all simplications for restarted analysis --- src/ai/analysis.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/ai/analysis.rs b/src/ai/analysis.rs index 7ec93a3..14f75db 100644 --- a/src/ai/analysis.rs +++ b/src/ai/analysis.rs @@ -167,6 +167,9 @@ pub fn analyze( let mut is_units = BTreeMap::new(); let mut rcfws = BTreeMap::new(); + + let mut has_restarted = false; + for id in &po { let def_ids = &elems[id]; let recursive = if def_ids.len() == 1 { @@ -206,6 +209,7 @@ pub fn analyze( }, restarted, ) = analyzer.analyze_body(body); + has_restarted |= restarted; if conf.print_functions.contains(&tcx.def_path_str(def_id)) { tracing::info!( "{:?}\n{}", @@ -240,10 +244,6 @@ pub fn analyze( } } - if restarted { - stack.clear(); - } - for (loc, sp) in stack.iter() { let must_writes: BTreeSet<_> = states .get(loc) @@ -411,6 +411,11 @@ pub fn analyze( } } + if has_restarted { + wbrets.clear(); + rcfws.clear(); + } + summaries .into_iter() .map(|(def_id, summary)| { From 1c26f76fbbe8df2f382a8917aecdac4900624da8 Mon Sep 17 00:00:00 2001 From: HoseongLee Date: Mon, 7 Oct 2024 21:13:26 +0900 Subject: [PATCH 10/14] Removed simplification for -m 1 and Added statistics --- src/ai/analysis.rs | 43 ++++++++++++++++--------------------------- src/transform.rs | 43 ++++++++++++++++++++++++++++++++++--------- 2 files changed, 50 insertions(+), 36 deletions(-) diff --git a/src/ai/analysis.rs b/src/ai/analysis.rs index 14f75db..eb51d5a 100644 --- a/src/ai/analysis.rs +++ b/src/ai/analysis.rs @@ -168,8 +168,6 @@ pub fn analyze( let mut rcfws = BTreeMap::new(); - let mut has_restarted = false; - for id in &po { let def_ids = &elems[id]; let recursive = if def_ids.len() == 1 { @@ -201,15 +199,11 @@ pub fn analyze( ); } - let ( - AnalyzedBody { - states, - writes_map, - init_state, - }, - restarted, - ) = analyzer.analyze_body(body); - has_restarted |= restarted; + let AnalyzedBody { + states, + writes_map, + init_state, + } = analyzer.analyze_body(body); if conf.print_functions.contains(&tcx.def_path_str(def_id)) { tracing::info!( "{:?}\n{}", @@ -399,6 +393,11 @@ pub fn analyze( } } } + + if conf.max_loop_head_states <= 1 { + wbrets.clear(); + } + if let Some(n) = &conf.function_times { let mut analysis_times: Vec<_> = analysis_times.into_iter().collect(); analysis_times.sort_by_key(|(_, t)| u128::MAX - *t); @@ -411,11 +410,6 @@ pub fn analyze( } } - if has_restarted { - wbrets.clear(); - rcfws.clear(); - } - summaries .into_iter() .map(|(def_id, summary)| { @@ -749,9 +743,8 @@ impl<'a, 'tcx> Analyzer<'a, 'tcx> { } } - fn analyze_body(&mut self, body: &Body<'tcx>) -> (AnalyzedBody, bool) { + fn analyze_body(&mut self, body: &Body<'tcx>) -> AnalyzedBody { let mut start_state = AbsState::bot(); - let mut restarted = false; start_state.writes = MustPathSet::top(); start_state.nulls = MustPathSet::top(); @@ -915,21 +908,17 @@ impl<'a, 'tcx> Analyzer<'a, 'tcx> { } } if restart { - restarted = true; continue 'analysis_loop; } } break (states, writes_map); }; - ( - AnalyzedBody { - states, - writes_map, - init_state, - }, - restarted, - ) + AnalyzedBody { + states, + writes_map, + init_state, + } } pub fn expands_path(&self, place: &AbsPath) -> Vec { diff --git a/src/transform.rs b/src/transform.rs index 841d9b5..1ef14c5 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -19,6 +19,12 @@ use rustfix::Suggestion; use crate::{ai::analysis::*, compile_util}; +static mut N_MUST: usize = 0; +static mut N_MAY: usize = 0; +static mut N_DIRECT_RETURNS: usize = 0; +static mut N_REMOVED_CHECKS: usize = 0; +static mut N_REMOVED_POINTERS: usize = 0; + pub fn transform_path( path: &Path, params: &OutputParams, @@ -276,7 +282,7 @@ fn transform( } } - for call in visitor.calls { + for call in visitor.calls.clone() { let Call { hir_id, span, @@ -458,6 +464,8 @@ fn transform( continue; } + unsafe { N_REMOVED_CHECKS += 1; } + let pos = span.hi() + BytePos(1); let span = span.with_hi(pos).with_lo(pos); let assign = format!("{0}___s = true;", param.name); @@ -521,6 +529,9 @@ fn transform( value: _, span: sp, } = assign; + if visitor.calls.iter().any(|call| call.span.overlaps(*sp)) { + continue; + } let sp2 = sp.between(*span); if source_map .span_to_snippet(sp2) @@ -566,6 +577,7 @@ fn transform( if let Some(spans) = ref_to_spans.get(¶m.name) { for span in spans { let assign = format!("{}___v", param.name); + unsafe { N_REMOVED_POINTERS += 1; } fix(*span, assign); } } @@ -640,6 +652,13 @@ fn transform( ) }); } + + println!("Number of must write before return simplifications : {}", unsafe { N_MUST }); + println!("Number of must not write before return simplifications : {}", unsafe { N_MAY }); + println!("Number of direct return simplifications : {}", unsafe { N_DIRECT_RETURNS }); + println!("Number of removed checks: {}", unsafe { N_REMOVED_CHECKS }); + println!("Number of removed pointers: {}", unsafe { N_REMOVED_POINTERS }); + suggestions } @@ -803,6 +822,7 @@ impl Func { let name = lit_map .and_then(|(n, v)| { if *n == param.name { + unsafe { N_DIRECT_RETURNS += 1; } Some((*v).clone()) } else { None @@ -811,8 +831,10 @@ impl Func { .unwrap_or(format!("{}___v", param.name)); let v = if let Some((may, must)) = wbret { if must.contains(¶m.name) { + unsafe { N_MUST += 1; } format!("Ok({})", name) } else if !may.contains(¶m.name) { + unsafe { N_MAY += 1; } format!("Err({})", orig) } else { format!( @@ -835,6 +857,7 @@ impl Func { let name = lit_map .and_then(|(n, v)| { if *n == param.name { + unsafe { N_DIRECT_RETURNS += 1; } Some((*v).clone()) } else { None @@ -845,8 +868,10 @@ impl Func { name } else if let Some((may, must)) = wbret { if must.contains(¶m.name) { + unsafe { N_MUST += 1; } format!("Some ({})", name) } else if !may.contains(¶m.name) { + unsafe { N_MAY += 1; } "None".to_string() } else { format!( @@ -876,7 +901,7 @@ struct Return { value: Option, } -#[derive(Debug)] +#[derive(Debug, Clone)] struct Call { hir_id: HirId, span: Span, @@ -1001,13 +1026,11 @@ impl<'tcx> BodyVisitor<'tcx> { ) { let source_map = self.tcx.sess.source_map(); if let ExprKind::Unary(UnOp::Deref, e) = lhs.kind { - if let ExprKind::Lit(_) = rhs.kind { - self.assigns.push(Assign { - name: source_map.span_to_snippet(e.span).unwrap(), - value: source_map.span_to_snippet(rhs.span).unwrap(), - span: expr.span.with_hi(expr.span.hi() + BytePos(1)), - }); - } + self.assigns.push(Assign { + name: source_map.span_to_snippet(e.span).unwrap(), + value: source_map.span_to_snippet(rhs.span).unwrap(), + span: expr.span.with_hi(expr.span.hi() + BytePos(1)), + }); } } } @@ -1220,10 +1243,12 @@ fn generate_set_flag( let rcfw = &rcfws.get(arg); if let Some(rcfw) = rcfw { if rcfw.iter().any(|sp| span.contains(*sp)) { + unsafe { N_REMOVED_CHECKS += 1; } return "".to_string(); } } return format!("{}___s = true;", arg); } + unsafe { N_REMOVED_CHECKS += 1; } "".to_string() } From a6c1254327db57a49c853a4f32d549e326153d5b Mon Sep 17 00:00:00 2001 From: HoseongLee Date: Tue, 8 Oct 2024 18:19:39 +0900 Subject: [PATCH 11/14] Removed all simplifications for -m 1 --- src/ai/analysis.rs | 1 + src/transform.rs | 58 +++++++++++++++++++++++++++++++++++----------- 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/src/ai/analysis.rs b/src/ai/analysis.rs index eb51d5a..ae30429 100644 --- a/src/ai/analysis.rs +++ b/src/ai/analysis.rs @@ -396,6 +396,7 @@ pub fn analyze( if conf.max_loop_head_states <= 1 { wbrets.clear(); + rcfws.clear(); } if let Some(n) = &conf.function_times { diff --git a/src/transform.rs b/src/transform.rs index 1ef14c5..f72444f 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -464,7 +464,9 @@ fn transform( continue; } - unsafe { N_REMOVED_CHECKS += 1; } + unsafe { + N_REMOVED_CHECKS += 1; + } let pos = span.hi() + BytePos(1); let span = span.with_hi(pos).with_lo(pos); @@ -577,7 +579,9 @@ fn transform( if let Some(spans) = ref_to_spans.get(¶m.name) { for span in spans { let assign = format!("{}___v", param.name); - unsafe { N_REMOVED_POINTERS += 1; } + unsafe { + N_REMOVED_POINTERS += 1; + } fix(*span, assign); } } @@ -653,11 +657,21 @@ fn transform( }); } - println!("Number of must write before return simplifications : {}", unsafe { N_MUST }); - println!("Number of must not write before return simplifications : {}", unsafe { N_MAY }); - println!("Number of direct return simplifications : {}", unsafe { N_DIRECT_RETURNS }); + println!( + "Number of must write before return simplifications : {}", + unsafe { N_MUST } + ); + println!( + "Number of must not write before return simplifications : {}", + unsafe { N_MAY } + ); + println!("Number of direct return simplifications : {}", unsafe { + N_DIRECT_RETURNS + }); println!("Number of removed checks: {}", unsafe { N_REMOVED_CHECKS }); - println!("Number of removed pointers: {}", unsafe { N_REMOVED_POINTERS }); + println!("Number of removed pointers: {}", unsafe { + N_REMOVED_POINTERS + }); suggestions } @@ -822,7 +836,9 @@ impl Func { let name = lit_map .and_then(|(n, v)| { if *n == param.name { - unsafe { N_DIRECT_RETURNS += 1; } + unsafe { + N_DIRECT_RETURNS += 1; + } Some((*v).clone()) } else { None @@ -831,10 +847,14 @@ impl Func { .unwrap_or(format!("{}___v", param.name)); let v = if let Some((may, must)) = wbret { if must.contains(¶m.name) { - unsafe { N_MUST += 1; } + unsafe { + N_MUST += 1; + } format!("Ok({})", name) } else if !may.contains(¶m.name) { - unsafe { N_MAY += 1; } + unsafe { + N_MAY += 1; + } format!("Err({})", orig) } else { format!( @@ -857,7 +877,9 @@ impl Func { let name = lit_map .and_then(|(n, v)| { if *n == param.name { - unsafe { N_DIRECT_RETURNS += 1; } + unsafe { + N_DIRECT_RETURNS += 1; + } Some((*v).clone()) } else { None @@ -868,10 +890,14 @@ impl Func { name } else if let Some((may, must)) = wbret { if must.contains(¶m.name) { - unsafe { N_MUST += 1; } + unsafe { + N_MUST += 1; + } format!("Some ({})", name) } else if !may.contains(¶m.name) { - unsafe { N_MAY += 1; } + unsafe { + N_MAY += 1; + } "None".to_string() } else { format!( @@ -1243,12 +1269,16 @@ fn generate_set_flag( let rcfw = &rcfws.get(arg); if let Some(rcfw) = rcfw { if rcfw.iter().any(|sp| span.contains(*sp)) { - unsafe { N_REMOVED_CHECKS += 1; } + unsafe { + N_REMOVED_CHECKS += 1; + } return "".to_string(); } } return format!("{}___s = true;", arg); } - unsafe { N_REMOVED_CHECKS += 1; } + unsafe { + N_REMOVED_CHECKS += 1; + } "".to_string() } From 7c38b923d9391f6e3ab774e547efe3fe7ee83413 Mon Sep 17 00:00:00 2001 From: Jaemin Hong Date: Tue, 5 Nov 2024 06:38:21 +0000 Subject: [PATCH 12/14] summary refactoring --- src/ai/analysis.rs | 153 ++++++++++++++++++++------------------------ src/bin/nopcrat.rs | 19 +++--- src/compile_util.rs | 50 ++++++++++++++- src/sampling.rs | 8 +-- src/transform.rs | 111 +++++++++++++++----------------- 5 files changed, 183 insertions(+), 158 deletions(-) diff --git a/src/ai/analysis.rs b/src/ai/analysis.rs index ae30429..c4b5916 100644 --- a/src/ai/analysis.rs +++ b/src/ai/analysis.rs @@ -4,6 +4,7 @@ use std::{ path::Path, }; +use compile_util::LoHi; use etrace::some_or; use rustc_abi::VariantIdx; use rustc_hir::{ @@ -47,15 +48,24 @@ impl Default for AnalysisConfig { } } -pub type OutputParams = BTreeMap>; +pub type AnalysisResult = BTreeMap; -// Write Before RETurn s -pub type Wbrets = BTreeMap, BTreeSet)>; +#[derive(Debug, Serialize, Deserialize)] +pub struct FnAnalysisRes { + pub output_params: Vec, + pub wbrs: Vec, + pub rcfws: Rcfws, +} -// Removable checks for Write s -pub type Rcfws = BTreeMap>; +#[derive(Debug, Serialize, Deserialize)] +pub struct WriteBeforeReturn { + pub span: LoHi, + pub mays: BTreeSet, + pub musts: BTreeSet, +} -pub type AnalysisResult = (OutputParams, BTreeMap); +// Removable checks for Write s +pub type Rcfws = BTreeMap>; pub fn analyze_path(path: &Path, conf: &AnalysisConfig) -> AnalysisResult { analyze_input(compile_util::path_to_input(path), conf) @@ -70,19 +80,18 @@ pub fn analyze_input(input: Input, conf: &AnalysisConfig) -> AnalysisResult { compile_util::run_compiler(config, |tcx| { analyze(tcx, conf) .into_iter() - .filter_map(|(def_id, (_, params, wbret, rcfw))| { - if params.is_empty() { + .filter_map(|(def_id, (_, res))| { + if res.output_params.is_empty() { None } else { - Some((tcx.def_path_str(def_id), (params, wbret, rcfw))) + Some((tcx.def_path_str(def_id), res)) } }) .collect::>() }) .unwrap() .into_iter() - .map(|(k, (v1, v2, v3))| ((k.clone(), v1), (k.clone(), (v2, v3)))) - .unzip() + .collect() } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] @@ -95,7 +104,7 @@ enum Write { pub fn analyze( tcx: TyCtxt<'_>, conf: &AnalysisConfig, -) -> BTreeMap, Wbrets, Rcfws)> { +) -> BTreeMap { let hir = tcx.hir(); let mut call_graph = BTreeMap::new(); @@ -162,8 +171,8 @@ pub fn analyze( let mut call_args_map = BTreeMap::new(); let mut analysis_times: BTreeMap<_, u128> = BTreeMap::new(); - let mut wbrets: BTreeMap = BTreeMap::new(); - let mut wbbbrets: BTreeMap>> = BTreeMap::new(); + let mut wbrs: BTreeMap> = BTreeMap::new(); + let mut bb_musts: BTreeMap>> = BTreeMap::new(); let mut is_units = BTreeMap::new(); let mut rcfws = BTreeMap::new(); @@ -219,8 +228,8 @@ pub fn analyze( let ret_location = return_location(body); - let mut wbret = BTreeMap::new(); - let mut wbbbret = BTreeMap::new(); + let mut wbr = vec![]; + let mut bb_must = BTreeMap::new(); let mut stack = vec![]; @@ -228,55 +237,47 @@ pub fn analyze( if let Some((ret_loc_assign0, ret_loc)) = exists_assign0(body, ret_location.block) { - stack.push((ret_loc, ret_loc_assign0.data())); + stack.push((ret_loc, ret_loc_assign0)); } else if let Some(v) = body.basic_blocks.predecessors().get(ret_location.block) { for i in v { if let Some((sp, ret_loc)) = exists_assign0(body, *i) { - stack.push((ret_loc, sp.data())); + stack.push((ret_loc, sp)); } } } + let empty_map = BTreeMap::new(); for (loc, sp) in stack.iter() { - let must_writes: BTreeSet<_> = states + let writes: Vec<_> = states .get(loc) - .cloned() - .unwrap_or_default() + .unwrap_or(&empty_map) .values() - .fold(None, |acc: Option>, st: &AbsState| { + .map(|st| st.writes.as_set()) + .collect(); + let musts: BTreeSet<_> = writes + .iter() + .copied() + .fold(None, |acc: Option>, ws| { Some(match acc { - Some(acc) => { - acc.intersection(st.writes.as_set()).cloned().collect() - } - None => st.writes.as_set().clone(), + Some(acc) => acc.intersection(ws).cloned().collect(), + None => ws.clone(), }) }) .unwrap_or_default() .iter() .map(|p| p.base() - 1) .collect(); - - let may_writes: BTreeSet<_> = states - .get(loc) - .cloned() - .unwrap_or_default() - .values() - .flat_map(|st| st.writes.as_set()) - .map(|p| p.base() - 1) - .collect(); - - wbret.insert( - unsafe { std::mem::transmute(*sp) }, - (may_writes, must_writes.clone()), - ); - - wbbbret.insert(loc.block, must_writes); + let mays: BTreeSet<_> = + writes.into_iter().flatten().map(|p| p.base() - 1).collect(); + let span = LoHi::from_span(*sp); + bb_must.insert(loc.block, musts.clone()); + wbr.push(WriteBeforeReturn { span, mays, musts }); } } - wbrets.insert(*def_id, wbret); - wbbbrets.insert(*def_id, wbbbret); + wbrs.insert(*def_id, wbr); + bb_musts.insert(*def_id, bb_must); is_units.insert(*def_id, stack.is_empty()); let mut return_states = ret_location @@ -322,47 +323,39 @@ pub fn analyze( } let body = tcx.optimized_mir(*def_id); - let wbbbret = &wbbbrets[def_id]; + let bb_must = &bb_musts[def_id]; let mut rcfw: Rcfws = BTreeMap::new(); if !is_units[def_id] { - for p in output_params.iter() { + for p in &output_params { let OutputParam { index, - must: _, - return_values: _, complete_writes, + .. } = p; - for complete_write in complete_writes.iter() { + for complete_write in complete_writes { let CompleteWrite { block, statement_index, - write_arg: _, + .. } = complete_write; let mut stack = vec![BasicBlock::from_usize(*block)]; - let mut visited = BTreeSet::from_iter(stack.clone()); + let mut visited: BTreeSet<_> = stack.iter().cloned().collect(); - let success = loop { - if let Some(block) = stack.pop() { - if let Some(ws) = wbbbret.get(&block) { - if !ws.contains(index) { + let always_write = loop { + if let Some(bb) = stack.pop() { + if let Some(musts) = bb_must.get(&bb) { + if !musts.contains(index) { break false; } } - let bbd = &body.basic_blocks[block]; - let term = bbd.terminator(); - - match term.kind { - TerminatorKind::Return => (), - _ => { - for bb in term.successors() { - if !visited.contains(&bb) { - visited.insert(bb); - stack.push(bb); - } - } + let term = body.basic_blocks[bb].terminator(); + for bb in term.successors() { + if !visited.contains(&bb) { + visited.insert(bb); + stack.push(bb); } } } else { @@ -370,15 +363,12 @@ pub fn analyze( } }; - if success { + if always_write { let location = Location { block: BasicBlock::from_usize(*block), statement_index: *statement_index, }; - let span = unsafe { - std::mem::transmute(body.source_info(location).span.data()) - }; - + let span = LoHi::from_span(body.source_info(location).span); let entry = rcfw.entry(*index); entry.or_default().insert(span); } @@ -395,7 +385,7 @@ pub fn analyze( } if conf.max_loop_head_states <= 1 { - wbrets.clear(); + wbrs.clear(); rcfws.clear(); } @@ -415,15 +405,14 @@ pub fn analyze( .into_iter() .map(|(def_id, summary)| { let output_params = output_params_map.remove(&def_id).unwrap(); - ( - def_id, - ( - summary, - output_params, - wbrets.get(&def_id).cloned().unwrap_or_default(), - rcfws.get(&def_id).cloned().unwrap_or_default(), - ), - ) + let wbrs = wbrs.remove(&def_id).unwrap_or_default(); + let rcfws = rcfws.remove(&def_id).unwrap_or_default(); + let res = FnAnalysisRes { + output_params, + wbrs, + rcfws, + }; + (def_id, (summary, res)) }) .collect() } diff --git a/src/bin/nopcrat.rs b/src/bin/nopcrat.rs index df7a814..5340b3f 100644 --- a/src/bin/nopcrat.rs +++ b/src/bin/nopcrat.rs @@ -98,16 +98,16 @@ fn main() { }; if args.verbose { - for (func, params) in &analysis_result.0 { + for (func, res) in &analysis_result { println!("{}", func); - for param in params { + for param in &res.output_params { println!(" {:?}", param); } } } if args.sample_negative { - let mut fns = sampling::sample_from_path(path, &analysis_result.0); + let mut fns = sampling::sample_from_path(path, &analysis_result); fns.shuffle(&mut thread_rng()); for f in fns.iter().take(10) { println!("{:?}", f); @@ -116,9 +116,8 @@ fn main() { } if args.sample_may || args.sample_must { let mut params: Vec<_> = analysis_result - .0 .iter() - .filter(|(_, params)| params.iter().any(|p| p.must == args.sample_must)) + .filter(|(_, res)| res.output_params.iter().any(|p| p.must == args.sample_must)) .collect(); params.shuffle(&mut thread_rng()); for (f, ps) in params.iter().take(10) { @@ -128,16 +127,14 @@ fn main() { } if args.use_analysis_result.is_none() { - let fns = analysis_result.0.len(); + let fns = analysis_result.len(); let musts = analysis_result - .0 .values() - .map(|v| v.iter().filter(|p| p.must).count()) + .map(|res| res.output_params.iter().filter(|p| p.must).count()) .sum::(); let mays = analysis_result - .0 .values() - .map(|v| v.iter().filter(|p| !p.must).count()) + .map(|res| res.output_params.iter().filter(|p| !p.must).count()) .sum::(); println!("{} {} {}", fns, musts, mays); } @@ -151,7 +148,7 @@ fn main() { return; } - transform::transform_path(path, &analysis_result.0, &analysis_result.1); + transform::transform_path(path, &analysis_result); } fn clear_dir(path: &Path) { diff --git a/src/compile_util.rs b/src/compile_util.rs index 1178f05..41f8c32 100644 --- a/src/compile_util.rs +++ b/src/compile_util.rs @@ -22,9 +22,57 @@ use rustc_session::{ use rustc_span::{ edition::Edition, source_map::{FileName, SourceMap}, - RealFileName, Span, + BytePos, RealFileName, Span, SpanData, SyntaxContext, }; use rustfix::{LinePosition, LineRange, Replacement, Snippet, Solution, Suggestion}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct LoHi { + lo: u32, + hi: u32, +} + +impl LoHi { + #[inline] + fn new(lo: u32, hi: u32) -> Self { + Self { lo, hi } + } + + #[inline] + pub fn from_span(span: Span) -> Self { + assert!(span.ctxt().is_root()); + assert!(span.parent().is_none()); + Self::new(span.lo().0, span.hi().0) + } + + #[inline] + pub fn from_span_data(span: SpanData) -> Self { + assert!(span.ctxt.is_root()); + assert!(span.parent.is_none()); + Self::new(span.lo.0, span.hi.0) + } + + #[inline] + pub fn to_span(self) -> Span { + Span::new( + BytePos(self.lo), + BytePos(self.hi), + SyntaxContext::root(), + None, + ) + } + + #[inline] + pub fn to_span_data(self) -> SpanData { + SpanData { + lo: BytePos(self.lo), + hi: BytePos(self.hi), + ctxt: SyntaxContext::root(), + parent: None, + } + } +} pub fn run_compiler) -> R + Send>(config: Config, f: F) -> Option { rustc_driver::catch_fatal_errors(|| { diff --git a/src/sampling.rs b/src/sampling.rs index ec882f4..e0a6a41 100644 --- a/src/sampling.rs +++ b/src/sampling.rs @@ -8,17 +8,17 @@ use rustc_middle::{ }; use rustc_session::config::Input; -use crate::{ai::analysis::OutputParams, compile_util}; +use crate::{ai::analysis::AnalysisResult, compile_util}; -pub fn sample_from_path(path: &Path, res: &OutputParams) -> Vec { +pub fn sample_from_path(path: &Path, res: &AnalysisResult) -> Vec { sample_from_input(compile_util::path_to_input(path), res) } -pub fn sample_from_code(code: &str, res: &OutputParams) -> Vec { +pub fn sample_from_code(code: &str, res: &AnalysisResult) -> Vec { sample_from_input(compile_util::str_to_input(code), res) } -fn sample_from_input(input: Input, res: &OutputParams) -> Vec { +fn sample_from_input(input: Input, res: &AnalysisResult) -> Vec { let config = compile_util::make_config(input); compile_util::run_compiler(config, |tcx| { let hir = tcx.hir(); diff --git a/src/transform.rs b/src/transform.rs index f72444f..e89381c 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -14,7 +14,7 @@ use rustc_middle::{ mir::{BasicBlock, TerminatorKind}, ty::TyCtxt, }; -use rustc_span::{def_id::DefId, source_map::SourceMap, BytePos, Span, SpanData}; +use rustc_span::{def_id::DefId, source_map::SourceMap, BytePos, Span}; use rustfix::Suggestion; use crate::{ai::analysis::*, compile_util}; @@ -25,22 +25,17 @@ static mut N_DIRECT_RETURNS: usize = 0; static mut N_REMOVED_CHECKS: usize = 0; static mut N_REMOVED_POINTERS: usize = 0; -pub fn transform_path( - path: &Path, - params: &OutputParams, - extra_info: &BTreeMap, -) { +pub fn transform_path(path: &Path, analysis_result: &AnalysisResult) { let input = compile_util::path_to_input(path); let config = compile_util::make_config(input); let suggestions = - compile_util::run_compiler(config, |tcx| transform(tcx, params, extra_info)).unwrap(); + compile_util::run_compiler(config, |tcx| transform(tcx, analysis_result)).unwrap(); compile_util::apply_suggestions(&suggestions); } fn transform( tcx: TyCtxt<'_>, - param_map: &OutputParams, - extra_info: &BTreeMap, + analysis_result: &AnalysisResult, ) -> BTreeMap> { let hir = tcx.hir(); let source_map = tcx.sess.source_map(); @@ -60,7 +55,7 @@ fn transform( } let mut funcs = BTreeMap::new(); - let mut wbrets = BTreeMap::new(); + let mut wbrs = BTreeMap::new(); let mut rcfws = BTreeMap::new(); for id in hir.items() { let item = hir.item(id); @@ -69,10 +64,11 @@ fn transform( }; let def_id = id.owner_id.to_def_id(); let name = tcx.def_path_str(def_id); - let params = some_or!(param_map.get(&name), continue); + let fn_analysis_result = some_or!(analysis_result.get(&name), continue); let body = hir.body(body_id); let mir_body = tcx.optimized_mir(def_id); - let index_map: BTreeMap<_, _> = params + let index_map: BTreeMap<_, _> = fn_analysis_result + .output_params .iter() .map(|param| { let OutputParam { @@ -150,48 +146,42 @@ fn transform( }) .collect(); - if let Some((wbret, _)) = extra_info.get(&name) { - wbrets.insert( - def_id, - wbret - .clone() - .into_iter() - .map(|(sp, (may, must))| { - ( - unsafe { std::mem::transmute::(sp) }.span(), - ( - may.iter() - .filter_map(|i| index_map.get(i).map(|p| p.name.clone())) - .collect::>(), - must.iter() - .filter_map(|i| index_map.get(i).map(|p| p.name.clone())) - .collect::>(), - ), - ) - }) - .collect::>(), - ); - } - - if let Some((_, rcfw)) = extra_info.get(&name) { - rcfws.insert( - def_id, - rcfw.clone() - .into_iter() - .map(|(index, spans)| { + wbrs.insert( + def_id, + fn_analysis_result + .wbrs + .iter() + .map(|wbr| { + ( + wbr.span.to_span(), ( - index_map.get(&index).cloned().unwrap().name, - spans + wbr.mays .iter() - .map(|sp| { - unsafe { std::mem::transmute::(*sp) }.span() - }) - .collect(), - ) - }) - .collect::>>(), - ); - } + .filter_map(|i| index_map.get(i).map(|p| p.name.clone())) + .collect::>(), + wbr.musts + .iter() + .filter_map(|i| index_map.get(i).map(|p| p.name.clone())) + .collect::>(), + ), + ) + }) + .collect::>(), + ); + + rcfws.insert( + def_id, + fn_analysis_result + .rcfws + .iter() + .map(|(index, spans)| { + ( + index_map.get(index).cloned().unwrap().name, + spans.iter().map(|sp| sp.to_span()).collect(), + ) + }) + .collect::>>(), + ); let hir_id_map: BTreeMap<_, _> = index_map .values() @@ -199,7 +189,7 @@ fn transform( .map(|param| (param.hir_id, param)) .collect(); let mut remaining_return: Vec<_> = index_map.keys().copied().collect(); - let first_return = SuccValue::find(params); + let first_return = SuccValue::find(&fn_analysis_result.output_params); if let Some((_, first)) = &first_return { remaining_return.retain(|i| i != first); } @@ -244,6 +234,7 @@ fn transform( let mut ret_call_spans = BTreeSet::new(); let mut call_spans = BTreeSet::new(); + // in deref expression, pointer name to expression span let mut ref_to_spans: BTreeMap> = BTreeMap::new(); let mut ret_to_ref_spans: BTreeMap> = BTreeMap::new(); @@ -398,7 +389,7 @@ fn transform( let rv = format!("{}rv___{}", pre_s, post_s); let rv = func.return_value( Some(rv), - wbrets[&def_id] + wbrs[&def_id] .iter() .find(|(sp, _)| span.contains(*sp)) .map(|r| &r.1), @@ -449,7 +440,7 @@ fn transform( let mut unremovable = BTreeSet::new(); for param in func.params() { - let rcfw = &rcfw.get(¶m.name); + let rcfw = rcfw.get(¶m.name); for span in ¶m.writes { if let Some(rcfw) = rcfw { @@ -566,7 +557,7 @@ fn transform( let ret_v = func.return_value( orig, - wbrets[&def_id] + wbrs[&def_id] .iter() .find(|(sp, _)| span.contains(*sp)) .map(|r| &r.1), @@ -614,7 +605,7 @@ fn transform( } let rv = func.return_value( Some(rv), - wbrets[&def_id] + wbrs[&def_id] .iter() .find(|(sp, _)| span.contains(*sp)) .map(|r| &r.1), @@ -826,7 +817,7 @@ impl Func { fn return_value( &self, orig: Option, - wbret: Option<&(Vec, Vec)>, + wbr: Option<&(Vec, Vec)>, lit_map: Option<(&String, &String)>, ) -> String { let mut values = vec![]; @@ -845,7 +836,7 @@ impl Func { } }) .unwrap_or(format!("{}___v", param.name)); - let v = if let Some((may, must)) = wbret { + let v = if let Some((may, must)) = wbr { if must.contains(¶m.name) { unsafe { N_MUST += 1; @@ -888,7 +879,7 @@ impl Func { .unwrap_or(format!("{}___v", param.name)); let v = if param.must { name - } else if let Some((may, must)) = wbret { + } else if let Some((may, must)) = wbr { if must.contains(¶m.name) { unsafe { N_MUST += 1; From 3c16ea67fd54460ded0d339738933d52efa2055e Mon Sep 17 00:00:00 2001 From: Jaemin Hong Date: Tue, 5 Nov 2024 07:20:52 +0000 Subject: [PATCH 13/14] simplify option & count pointer def/use --- src/bin/nopcrat.rs | 4 +- src/transform.rs | 120 +++++++++++++++++++++++++-------------------- 2 files changed, 69 insertions(+), 55 deletions(-) diff --git a/src/bin/nopcrat.rs b/src/bin/nopcrat.rs index 5340b3f..00a6a95 100644 --- a/src/bin/nopcrat.rs +++ b/src/bin/nopcrat.rs @@ -25,6 +25,8 @@ struct Args { #[arg(short, long)] transform: bool, + #[arg(long)] + simplify: bool, #[arg(short, long)] size: bool, #[arg(long)] @@ -148,7 +150,7 @@ fn main() { return; } - transform::transform_path(path, &analysis_result); + transform::transform_path(path, &analysis_result, args.simplify); } fn clear_dir(path: &Path) { diff --git a/src/transform.rs b/src/transform.rs index e89381c..1e04ade 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -23,20 +23,29 @@ static mut N_MUST: usize = 0; static mut N_MAY: usize = 0; static mut N_DIRECT_RETURNS: usize = 0; static mut N_REMOVED_CHECKS: usize = 0; -static mut N_REMOVED_POINTERS: usize = 0; -pub fn transform_path(path: &Path, analysis_result: &AnalysisResult) { +#[derive(Default, Clone, Copy)] +struct Counter { + removed_pointer_defs: usize, + removed_pointer_uses: usize, +} + +pub fn transform_path(path: &Path, analysis_result: &AnalysisResult, simplify: bool) { let input = compile_util::path_to_input(path); let config = compile_util::make_config(input); let suggestions = - compile_util::run_compiler(config, |tcx| transform(tcx, analysis_result)).unwrap(); + compile_util::run_compiler(config, |tcx| transform(tcx, analysis_result, simplify)) + .unwrap(); compile_util::apply_suggestions(&suggestions); } fn transform( tcx: TyCtxt<'_>, analysis_result: &AnalysisResult, + simplify: bool, ) -> BTreeMap> { + let mut counter = Counter::default(); + let hir = tcx.hir(); let source_map = tcx.sess.source_map(); @@ -237,38 +246,42 @@ fn transform( // in deref expression, pointer name to expression span let mut ref_to_spans: BTreeMap> = BTreeMap::new(); + // in deref expression in return, return span to pointer name and deref span let mut ret_to_ref_spans: BTreeMap> = BTreeMap::new(); + // in deref expression in call, pointer name and deref span let mut ref_and_call_spans = vec![]; - if let Some(func) = curr { - let hirids = BTreeSet::from_iter( - visitor - .calls - .iter() - .filter_map(|c| funcs.get(&c.callee).map(|_| c.hir_id)), - ); - let params = BTreeSet::from_iter(func.params().map(|p| p.name.clone())); - - for rf in visitor.refs { - let Ref { hir_id, span, name } = rf; - if params.contains(&name) && !passes.contains(&name) { - if let Some(expr) = get_parent_call(hir_id, tcx) { - if hirids.contains(&expr.hir_id) { - ref_and_call_spans.push((name, span)); + if simplify { + if let Some(func) = curr { + let hirids = BTreeSet::from_iter( + visitor + .calls + .iter() + .filter_map(|c| funcs.get(&c.callee).map(|_| c.hir_id)), + ); + let params = BTreeSet::from_iter(func.params().map(|p| p.name.clone())); + + for rf in visitor.refs { + let Ref { hir_id, span, name } = rf; + if params.contains(&name) && !passes.contains(&name) { + if let Some(expr) = get_parent_call(hir_id, tcx) { + if hirids.contains(&expr.hir_id) { + ref_and_call_spans.push((name, span)); + continue; + } + } + + if let Some(expr) = get_parent_return(hir_id, tcx) { + ret_to_ref_spans + .entry(expr.span) + .or_default() + .push((name, span)); continue; } - } - if let Some(expr) = get_parent_return(hir_id, tcx) { - ret_to_ref_spans - .entry(expr.span) - .or_default() - .push((name, span)); - continue; + ref_to_spans.entry(name.clone()).or_default().push(span); } - - ref_to_spans.entry(name.clone()).or_default().push(span); } } } @@ -286,7 +299,7 @@ fn transform( let mut args = args.clone(); for arg in args.iter_mut() { - for (name, span) in ref_and_call_spans.iter() { + for (name, span) in &ref_and_call_spans { if arg.span.contains(*span) { let pre_span = arg.span.with_hi(span.lo()); let post_span = arg.span.with_lo(span.hi()); @@ -295,6 +308,7 @@ fn transform( let post_s = source_map.span_to_snippet(post_span).unwrap(); arg.code = format!("{}{}___v{}", pre_s, name, post_s); + counter.removed_pointer_uses += 1; } } } @@ -470,7 +484,7 @@ fn transform( .params() .map(|param| { if param.must || (!unremovable.contains(¶m.name)) { - if passes.contains(¶m.name) { + if passes.contains(¶m.name) || !simplify { format!( " let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]); \ @@ -478,13 +492,14 @@ fn transform( param.name, param.ty, ) } else { + counter.removed_pointer_defs += 1; format!( " let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]);", param.name, param.ty, ) } - } else if passes.contains(¶m.name) { + } else if passes.contains(¶m.name) || !simplify { format!( " let mut {0}___s: bool = false; \ @@ -493,6 +508,7 @@ fn transform( param.name, param.ty, ) } else { + counter.removed_pointer_defs += 1; format!( " let mut {0}___s: bool = false; \ @@ -570,26 +586,23 @@ fn transform( if let Some(spans) = ref_to_spans.get(¶m.name) { for span in spans { let assign = format!("{}___v", param.name); - unsafe { - N_REMOVED_POINTERS += 1; - } + counter.removed_pointer_uses += 1; fix(*span, assign); } } } - for (span, ss) in ret_to_ref_spans.iter() { + for (span, mut ss) in ret_to_ref_spans { let mut post_spans = vec![]; - let mut sorted_ss = ss.clone(); - sorted_ss.sort_by_key(|x| x.1); + ss.sort_by_key(|x| x.1); - let s = sorted_ss[0].1; + let s = ss[0].1; let pre_span = span.with_hi(s.lo()).with_lo(span.lo() + BytePos(6)); post_spans.push(span.with_lo(s.hi())); - for (i, (_, s)) in sorted_ss[1..].iter().enumerate() { + for (i, (_, s)) in ss[1..].iter().enumerate() { post_spans[i] = post_spans[i].with_hi(s.lo()); post_spans.push(span.with_lo(s.hi())); } @@ -597,11 +610,15 @@ fn transform( let pre_s = source_map.span_to_snippet(pre_span).unwrap(); let post_s = source_map.span_to_snippet(post_spans[0]).unwrap(); - let mut rv = format!("{}{}___v{}", pre_s, sorted_ss[0].0, post_s); + let mut rv = format!("{}{}___v{}", pre_s, ss[0].0, post_s); + counter.removed_pointer_uses += 1; for (i, s) in post_spans[1..].iter().enumerate() { let post_s = source_map.span_to_snippet(*s).unwrap(); - rv = format!("{}{}___v{}", rv, sorted_ss[i + 1].0, post_s); + rv.push_str(&ss[i + 1].0); + rv.push_str("___v"); + rv.push_str(&post_s); + counter.removed_pointer_uses += 1; } let rv = func.return_value( Some(rv), @@ -612,7 +629,7 @@ fn transform( None, ); - fix(*span, format!("return {}", rv)); + fix(span, format!("return {}", rv)); } if func.is_unit { @@ -648,21 +665,19 @@ fn transform( }); } + println!("Removed pointer defs: {}", counter.removed_pointer_defs); + println!("Removed pointer uses: {}", counter.removed_pointer_uses); + println!("Must write before return simplifications : {}", unsafe { + N_MUST + }); println!( - "Number of must write before return simplifications : {}", - unsafe { N_MUST } - ); - println!( - "Number of must not write before return simplifications : {}", + "Must not write before return simplifications : {}", unsafe { N_MAY } ); - println!("Number of direct return simplifications : {}", unsafe { + println!("Direct return simplifications : {}", unsafe { N_DIRECT_RETURNS }); - println!("Number of removed checks: {}", unsafe { N_REMOVED_CHECKS }); - println!("Number of removed pointers: {}", unsafe { - N_REMOVED_POINTERS - }); + println!("Removed checks: {}", unsafe { N_REMOVED_CHECKS }); suggestions } @@ -1268,8 +1283,5 @@ fn generate_set_flag( } return format!("{}___s = true;", arg); } - unsafe { - N_REMOVED_CHECKS += 1; - } "".to_string() } From bda75eabc80f18ec71018815e549427c856a8e0a Mon Sep 17 00:00:00 2001 From: Jaemin Hong Date: Tue, 5 Nov 2024 08:25:46 +0000 Subject: [PATCH 14/14] count other simplifications --- src/transform.rs | 124 ++++++++++++++++++++++------------------------- 1 file changed, 57 insertions(+), 67 deletions(-) diff --git a/src/transform.rs b/src/transform.rs index 1e04ade..e7c2c85 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -19,15 +19,16 @@ use rustfix::Suggestion; use crate::{ai::analysis::*, compile_util}; -static mut N_MUST: usize = 0; -static mut N_MAY: usize = 0; -static mut N_DIRECT_RETURNS: usize = 0; -static mut N_REMOVED_CHECKS: usize = 0; - #[derive(Default, Clone, Copy)] struct Counter { + simplify: bool, removed_pointer_defs: usize, removed_pointer_uses: usize, + direct_returns: usize, + success_returns: usize, + failure_returns: usize, + removed_flag_sets: usize, + removed_flag_defs: usize, } pub fn transform_path(path: &Path, analysis_result: &AnalysisResult, simplify: bool) { @@ -44,7 +45,10 @@ fn transform( analysis_result: &AnalysisResult, simplify: bool, ) -> BTreeMap> { - let mut counter = Counter::default(); + let mut counter = Counter { + simplify, + ..Counter::default() + }; let hir = tcx.hir(); let source_map = tcx.sess.source_map(); @@ -321,7 +325,7 @@ fn transform( let assign_map = curr.map(|c| c.assign_map(span)).unwrap_or_default(); let mut mtch = func.first_return.and_then(|(_, first)| { - let set_flag = generate_set_flag(&span, &first, rcfw, &assign_map); + let set_flag = generate_set_flag(&span, &first, rcfw, &assign_map, &mut counter); func.call_match(&args, set_flag) }); @@ -335,7 +339,7 @@ fn transform( let fail = "Err(_) => "; let (_, i) = func.first_return.as_ref().unwrap(); let arg = &args[*i]; - let set_flag = generate_set_flag(&span, i, rcfw, &assign_map); + let set_flag = generate_set_flag(&span, i, rcfw, &assign_map, &mut counter); let assign = if arg.code.contains("&mut ") { format!(" *({}) = v___; {}", arg.code, set_flag) } else { @@ -408,6 +412,7 @@ fn transform( .find(|(sp, _)| span.contains(*sp)) .map(|r| &r.1), None, + &mut counter, ); fix( post_span, @@ -425,7 +430,12 @@ fn transform( let set_flags = func .remaining_return .iter() - .map(|i| (*i, generate_set_flag(&span, i, rcfw, &assign_map))) + .map(|i| { + ( + *i, + generate_set_flag(&span, i, rcfw, &assign_map, &mut counter), + ) + }) .collect(); let mut assign = func.call_assign(&args, &set_flags); @@ -458,7 +468,8 @@ fn transform( for span in ¶m.writes { if let Some(rcfw) = rcfw { - if rcfw.iter().any(|sp| span.contains(*sp)) { + if rcfw.iter().any(|sp| span.contains(*sp)) && simplify { + counter.removed_flag_sets += 1; continue; } } @@ -469,10 +480,6 @@ fn transform( continue; } - unsafe { - N_REMOVED_CHECKS += 1; - } - let pos = span.hi() + BytePos(1); let span = span.with_hi(pos).with_lo(pos); let assign = format!("{0}___s = true;", param.name); @@ -483,7 +490,10 @@ fn transform( let local_vars: String = func .params() .map(|param| { - if param.must || (!unremovable.contains(¶m.name)) { + if param.must || (!unremovable.contains(¶m.name) && simplify) { + if !param.must { + counter.removed_flag_defs += 1; + } if passes.contains(¶m.name) || !simplify { format!( " @@ -533,22 +543,18 @@ fn transform( let mut assign_before_ret = None; for assign in visitor.assigns.iter() { - let Assign { - name: _, - value: _, - span: sp, - } = assign; - if visitor.calls.iter().any(|call| call.span.overlaps(*sp)) { + if visitor + .calls + .iter() + .any(|call| call.span.overlaps(assign.span)) + { continue; } - let sp2 = sp.between(*span); if source_map - .span_to_snippet(sp2) + .span_to_snippet(assign.span.between(*span)) .unwrap() .chars() - .filter(|c| !c.is_whitespace()) - .count() - == 0 + .all(|c| c.is_whitespace()) { assign_before_ret = Some(assign); break; @@ -578,6 +584,7 @@ fn transform( .find(|(sp, _)| span.contains(*sp)) .map(|r| &r.1), lit_map, + &mut counter, ); fix(*span, format!("return {}", ret_v)); } @@ -627,6 +634,7 @@ fn transform( .find(|(sp, _)| span.contains(*sp)) .map(|r| &r.1), None, + &mut counter, ); fix(span, format!("return {}", rv)); @@ -649,7 +657,7 @@ fn transform( } if !skip { - let ret_v = func.return_value(None, None, None); + let ret_v = func.return_value(None, None, None, &mut counter); fix(span, format!("\t{}\n", ret_v)); } } @@ -667,17 +675,11 @@ fn transform( println!("Removed pointer defs: {}", counter.removed_pointer_defs); println!("Removed pointer uses: {}", counter.removed_pointer_uses); - println!("Must write before return simplifications : {}", unsafe { - N_MUST - }); - println!( - "Must not write before return simplifications : {}", - unsafe { N_MAY } - ); - println!("Direct return simplifications : {}", unsafe { - N_DIRECT_RETURNS - }); - println!("Removed checks: {}", unsafe { N_REMOVED_CHECKS }); + println!("Direct returns: {}", counter.direct_returns); + println!("Success returns: {}", counter.success_returns); + println!("Failure returns: {}", counter.failure_returns); + println!("Removed flag sets: {}", counter.removed_flag_sets); + println!("Removed flag defs: {}", counter.removed_flag_defs); suggestions } @@ -834,6 +836,7 @@ impl Func { orig: Option, wbr: Option<&(Vec, Vec)>, lit_map: Option<(&String, &String)>, + counter: &mut Counter, ) -> String { let mut values = vec![]; if let Some((_, i)) = &self.first_return { @@ -841,10 +844,8 @@ impl Func { let param = &self.index_map[i]; let name = lit_map .and_then(|(n, v)| { - if *n == param.name { - unsafe { - N_DIRECT_RETURNS += 1; - } + if *n == param.name && counter.simplify { + counter.direct_returns += 1; Some((*v).clone()) } else { None @@ -852,15 +853,11 @@ impl Func { }) .unwrap_or(format!("{}___v", param.name)); let v = if let Some((may, must)) = wbr { - if must.contains(¶m.name) { - unsafe { - N_MUST += 1; - } + if must.contains(¶m.name) && counter.simplify { + counter.success_returns += 1; format!("Ok({})", name) - } else if !may.contains(¶m.name) { - unsafe { - N_MAY += 1; - } + } else if !may.contains(¶m.name) && counter.simplify { + counter.failure_returns += 1; format!("Err({})", orig) } else { format!( @@ -882,10 +879,8 @@ impl Func { let param = &self.index_map[i]; let name = lit_map .and_then(|(n, v)| { - if *n == param.name { - unsafe { - N_DIRECT_RETURNS += 1; - } + if *n == param.name && counter.simplify { + counter.direct_returns += 1; Some((*v).clone()) } else { None @@ -895,15 +890,11 @@ impl Func { let v = if param.must { name } else if let Some((may, must)) = wbr { - if must.contains(¶m.name) { - unsafe { - N_MUST += 1; - } - format!("Some ({})", name) - } else if !may.contains(¶m.name) { - unsafe { - N_MAY += 1; - } + if must.contains(¶m.name) && counter.simplify { + counter.success_returns += 1; + format!("Some({})", name) + } else if !may.contains(¶m.name) && counter.simplify { + counter.failure_returns += 1; "None".to_string() } else { format!( @@ -1270,14 +1261,13 @@ fn generate_set_flag( i: &usize, rcfws: &BTreeMap>, assign_map: &BTreeMap, + counter: &mut Counter, ) -> String { if let Some(arg) = assign_map.get(i) { let rcfw = &rcfws.get(arg); if let Some(rcfw) = rcfw { - if rcfw.iter().any(|sp| span.contains(*sp)) { - unsafe { - N_REMOVED_CHECKS += 1; - } + if rcfw.iter().any(|sp| span.contains(*sp)) && counter.simplify { + counter.removed_flag_sets += 1; return "".to_string(); } }