From 82eae1f3457aca04d5aac1554023df5748865047 Mon Sep 17 00:00:00 2001 From: HoseongLee Date: Fri, 27 Sep 2024 15:19:22 +0900 Subject: [PATCH] Combined assign and return & disabled simplification for restarted analysis --- src/ai/analysis.rs | 34 ++++-- src/transform.rs | 252 +++++++++++++++++++++++++++++---------------- 2 files changed, 186 insertions(+), 100 deletions(-) diff --git a/src/ai/analysis.rs b/src/ai/analysis.rs index 330c1d7..7ec93a3 100644 --- a/src/ai/analysis.rs +++ b/src/ai/analysis.rs @@ -198,11 +198,14 @@ pub fn analyze( ); } - let AnalyzedBody { - states, - writes_map, - init_state, - } = analyzer.analyze_body(body); + let ( + AnalyzedBody { + states, + writes_map, + init_state, + }, + restarted, + ) = analyzer.analyze_body(body); if conf.print_functions.contains(&tcx.def_path_str(def_id)) { tracing::info!( "{:?}\n{}", @@ -237,6 +240,10 @@ pub fn analyze( } } + if restarted { + stack.clear(); + } + for (loc, sp) in stack.iter() { let must_writes: BTreeSet<_> = states .get(loc) @@ -737,8 +744,9 @@ impl<'a, 'tcx> Analyzer<'a, 'tcx> { } } - fn analyze_body(&mut self, body: &Body<'tcx>) -> AnalyzedBody { + fn analyze_body(&mut self, body: &Body<'tcx>) -> (AnalyzedBody, bool) { let mut start_state = AbsState::bot(); + let mut restarted = false; start_state.writes = MustPathSet::top(); start_state.nulls = MustPathSet::top(); @@ -902,17 +910,21 @@ impl<'a, 'tcx> Analyzer<'a, 'tcx> { } } if restart { + restarted = true; continue 'analysis_loop; } } break (states, writes_map); }; - AnalyzedBody { - states, - writes_map, - init_state, - } + ( + AnalyzedBody { + states, + writes_map, + init_state, + }, + restarted, + ) } pub fn expands_path(&self, place: &AbsPath) -> Vec { diff --git a/src/transform.rs b/src/transform.rs index 431a28f..841d9b5 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -233,10 +233,7 @@ fn transform( let mut visitor = BodyVisitor::new(tcx); visitor.visit_body(body); - let mut pass_visitor = PassVisitor::new(tcx); - pass_visitor.visit_body(body); - - let passes = BTreeSet::from_iter(pass_visitor.passes.iter()); + let passes = BTreeSet::from_iter(visitor.passes.iter()); let mut ret_call_spans = BTreeSet::new(); let mut call_spans = BTreeSet::new(); @@ -393,13 +390,14 @@ fn transform( let post_span = post_span.with_hi(post_span.hi() + BytePos(1)); let rv = format!("{}rv___{}", pre_s, post_s); - let rv = if let Some((_, wbret)) = - wbrets[&def_id].iter().find(|(sp, _)| span.contains(*sp)) - { - func.return_value(Some(rv), Some(wbret)) - } else { - func.return_value(Some(rv), None) - }; + let rv = func.return_value( + Some(rv), + wbrets[&def_id] + .iter() + .find(|(sp, _)| span.contains(*sp)) + .map(|r| &r.1), + None, + ); fix( post_span, format!("; return {};{}", rv, if arm { " }" } else { "" }), @@ -508,6 +506,62 @@ fn transform( let span = body.value.span.with_lo(pos).with_hi(pos); fix(span, local_vars); + for ret in visitor.returns.iter() { + let Return { span, value } = ret; + if ret_call_spans.contains(span) || ret_to_ref_spans.contains_key(span) { + continue; + } + + let orig = value.map(|value| source_map.span_to_snippet(value).unwrap()); + + let mut assign_before_ret = None; + for assign in visitor.assigns.iter() { + let Assign { + name: _, + value: _, + span: sp, + } = assign; + let sp2 = sp.between(*span); + if source_map + .span_to_snippet(sp2) + .unwrap() + .chars() + .filter(|c| !c.is_whitespace()) + .count() + == 0 + { + assign_before_ret = Some(assign); + break; + } + } + + let mut lit_map = None; + + if let Some(assign_before_ret) = assign_before_ret { + let Assign { + name, + value, + span: sp, + } = assign_before_ret; + + if let Some(spans) = ref_to_spans.get_mut(name) { + spans.retain(|span| !sp.contains(*span)); + lit_map = Some((name, value)); + fix(*sp, "".to_string()); + } + } + + let ret_v = func.return_value( + orig, + wbrets[&def_id] + .iter() + .find(|(sp, _)| span.contains(*sp)) + .map(|r| &r.1), + lit_map, + ); + fix(*span, format!("return {}", ret_v)); + } + for param in func.params() { if let Some(spans) = ref_to_spans.get(¶m.name) { for span in spans { @@ -542,33 +596,18 @@ fn transform( let post_s = source_map.span_to_snippet(*s).unwrap(); rv = format!("{}{}___v{}", rv, sorted_ss[i + 1].0, post_s); } - let rv = if let Some((_, wbret)) = - wbrets[&def_id].iter().find(|(sp, _)| span.contains(*sp)) - { - func.return_value(Some(rv), Some(wbret)) - } else { - func.return_value(Some(rv), None) - }; + let rv = func.return_value( + Some(rv), + wbrets[&def_id] + .iter() + .find(|(sp, _)| span.contains(*sp)) + .map(|r| &r.1), + None, + ); fix(*span, format!("return {}", rv)); } - for ret in visitor.returns.clone() { - let Return { span, value } = ret; - if ret_call_spans.contains(&span) || ret_to_ref_spans.contains_key(&span) { - continue; - } - let orig = value.map(|value| source_map.span_to_snippet(value).unwrap()); - let ret_v = if let Some((_, wbret)) = - wbrets[&def_id].iter().find(|(sp, _)| span.contains(*sp)) - { - func.return_value(orig, Some(wbret)) - } else { - func.return_value(orig, None) - }; - fix(span, format!("return {}", ret_v)); - } - if func.is_unit { let pos = body.value.span.hi() - BytePos(1); let span = body.value.span.with_lo(pos).with_hi(pos); @@ -586,8 +625,8 @@ fn transform( } if !skip { - let ret_v = func.return_value(None, None); - fix(span, ret_v); + let ret_v = func.return_value(None, None, None); + fix(span, format!("\t{}\n", ret_v)); } } } @@ -755,26 +794,36 @@ impl Func { &self, orig: Option, wbret: Option<&(Vec, Vec)>, + lit_map: Option<(&String, &String)>, ) -> String { let mut values = vec![]; if let Some((_, i)) = &self.first_return { let orig = orig.unwrap(); let param = &self.index_map[i]; + let name = lit_map + .and_then(|(n, v)| { + if *n == param.name { + Some((*v).clone()) + } else { + None + } + }) + .unwrap_or(format!("{}___v", param.name)); let v = if let Some((may, must)) = wbret { if must.contains(¶m.name) { - format!("Ok({}___v)", param.name) + format!("Ok({})", name) } else if !may.contains(¶m.name) { format!("Err({})", orig) } else { format!( - "if {0}___s {{ Ok({0}___v) }} else {{ Err({1}) }}", - param.name, orig + "if {0}___s {{ Ok({1}) }} else {{ Err({2}) }}", + param.name, name, orig ) } } else { format!( - "if {0}___s {{ Ok({0}___v) }} else {{ Err({1}) }}", - param.name, orig + "if {0}___s {{ Ok({1}) }} else {{ Err({2}) }}", + param.name, name, orig ) }; values.push(v); @@ -783,18 +832,33 @@ impl Func { } for i in &self.remaining_return { let param = &self.index_map[i]; + let name = lit_map + .and_then(|(n, v)| { + if *n == param.name { + Some((*v).clone()) + } else { + None + } + }) + .unwrap_or(format!("{}___v", param.name)); let v = if param.must { - format!("{}___v", param.name) + name } else if let Some((may, must)) = wbret { if must.contains(¶m.name) { - format!("Some ({}___v)", param.name) + format!("Some ({})", name) } else if !may.contains(¶m.name) { "None".to_string() } else { - format!("if {0}___s {{ Some({0}___v) }} else {{ None }}", param.name) + format!( + "if {0}___s {{ Some({1}) }} else {{ None }}", + param.name, name + ) } } else { - format!("if {0}___s {{ Some({0}___v) }} else {{ None }}", param.name) + format!( + "if {0}___s {{ Some({1}) }} else {{ None }}", + param.name, name + ) }; values.push(v); } @@ -833,44 +897,11 @@ struct Ref { name: String, } -struct PassVisitor<'tcx> { - tcx: TyCtxt<'tcx>, - passes: Vec, -} - -impl<'tcx> PassVisitor<'tcx> { - fn new(tcx: TyCtxt<'tcx>) -> Self { - Self { - tcx, - passes: vec![], - } - } -} - -impl<'tcx> HVisitor<'tcx> for PassVisitor<'tcx> { - type NestedFilter = nested_filter::OnlyBodies; - - fn nested_visit_map(&mut self) -> Self::Map { - self.tcx.hir() - } - - fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) { - let source_map = self.tcx.sess.source_map(); - if let ExprKind::Path(p) = expr.kind { - if let Ok(code) = source_map.span_to_snippet(p.qself_span()) { - match get_parent(expr.hir_id, self.tcx) { - Some(e) => { - if let ExprKind::Unary(UnOp::Deref, _) = e.kind { - } else { - self.passes.push(code) - } - } - None => self.passes.push(code), - } - } - } - rustc_hir::intravisit::walk_expr(self, expr); - } +#[derive(Debug)] +struct Assign { + name: String, + value: String, + span: Span, } struct BodyVisitor<'tcx> { @@ -878,6 +909,8 @@ struct BodyVisitor<'tcx> { returns: Vec, calls: Vec, refs: Vec, + passes: Vec, + assigns: Vec, } impl<'tcx> BodyVisitor<'tcx> { @@ -887,6 +920,8 @@ impl<'tcx> BodyVisitor<'tcx> { returns: vec![], calls: vec![], refs: vec![], + passes: vec![], + assigns: vec![], } } } @@ -933,6 +968,48 @@ impl<'tcx> BodyVisitor<'tcx> { }; self.calls.push(call); } + + fn visit_expr_path(&mut self, expr: &'tcx Expr<'tcx>, path: QPath<'tcx>) { + let source_map = self.tcx.sess.source_map(); + if let Ok(code) = source_map.span_to_snippet(path.qself_span()) { + match get_parent(expr.hir_id, self.tcx) { + Some(e) => { + if let ExprKind::Unary(UnOp::Deref, _) = e.kind { + } else { + self.passes.push(code) + } + } + None => self.passes.push(code), + } + } + } + + fn visit_expr_unary(&mut self, expr: &'tcx Expr<'tcx>, e: &'tcx Expr<'tcx>) { + let source_map = self.tcx.sess.source_map(); + self.refs.push(Ref { + hir_id: expr.hir_id, + span: expr.span, + name: source_map.span_to_snippet(e.span).unwrap(), + }) + } + + fn visit_expr_assign( + &mut self, + expr: &'tcx Expr<'tcx>, + lhs: &'tcx Expr<'tcx>, + rhs: &'tcx Expr<'tcx>, + ) { + let source_map = self.tcx.sess.source_map(); + if let ExprKind::Unary(UnOp::Deref, e) = lhs.kind { + if let ExprKind::Lit(_) = rhs.kind { + self.assigns.push(Assign { + name: source_map.span_to_snippet(e.span).unwrap(), + value: source_map.span_to_snippet(rhs.span).unwrap(), + span: expr.span.with_hi(expr.span.hi() + BytePos(1)), + }); + } + } + } } impl<'tcx> HVisitor<'tcx> for BodyVisitor<'tcx> { @@ -943,15 +1020,12 @@ impl<'tcx> HVisitor<'tcx> for BodyVisitor<'tcx> { } fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) { - let source_map = self.tcx.sess.source_map(); match expr.kind { - ExprKind::Unary(UnOp::Deref, e) => self.refs.push(Ref { - hir_id: expr.hir_id, - span: expr.span, - name: source_map.span_to_snippet(e.span).unwrap(), - }), ExprKind::Ret(e) => self.visit_expr_ret(expr, e), ExprKind::Call(callee, args) => self.visit_expr_call(expr, callee, args), + ExprKind::Path(path) => self.visit_expr_path(expr, path), + ExprKind::Unary(UnOp::Deref, e) => self.visit_expr_unary(expr, e), + ExprKind::Assign(lhs, rhs, _) => self.visit_expr_assign(expr, lhs, rhs), _ => {} } rustc_hir::intravisit::walk_expr(self, expr);