From a781f97ddf65120a360733883f3e04f834e44fcb Mon Sep 17 00:00:00 2001 From: HoseongLee Date: Tue, 30 Jul 2024 13:38:03 +0900 Subject: [PATCH 1/2] must output parameters --- src/transform.rs | 203 ++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 192 insertions(+), 11 deletions(-) diff --git a/src/transform.rs b/src/transform.rs index 92c52c8..db66f4a 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -7,7 +7,7 @@ use etrace::some_or; use rustc_ast::LitKind; use rustc_hir::{ def::Res, intravisit::Visitor as HVisitor, BinOpKind, Expr, ExprKind, FnRetTy, HirId, ItemKind, - MutTy, Node, PatKind, QPath, TyKind, + MutTy, Node, PatKind, QPath, TyKind, UnOp, }; use rustc_middle::{ hir::nested_filter, @@ -178,9 +178,49 @@ fn transform( let mut visitor = BodyVisitor::new(tcx); visitor.visit_body(body); + let mut pass_visitor = PassVisitor::new(tcx); + pass_visitor.visit_body(body); + + let passes = BTreeSet::from_iter(pass_visitor.passes.iter()); + let mut ret_call_spans = BTreeSet::new(); let mut call_spans = BTreeSet::new(); + let mut ref_to_spans: BTreeMap> = BTreeMap::new(); + + let mut ret_to_ref_spans: BTreeMap> = BTreeMap::new(); + + let mut ref_and_call_spans = vec![]; + + if let Some(func) = curr { + let hirids = BTreeSet::from_iter( + visitor + .calls + .iter() + .filter_map(|c| funcs.get(&c.callee).map(|_| c.hir_id)), + ); + let params = BTreeSet::from_iter(func.params().map(|p| p.name.clone())); + + for rf in visitor.refs { + let Ref { hir_id, span, name } = rf; + if params.contains(&name) && !passes.contains(&name) { + if let Some(expr) = get_parent_call(hir_id, tcx) { + if hirids.contains(&expr.hir_id) { + ref_and_call_spans.push((name, span)); + continue; + } + } + + if let Some(expr) = get_parent_return(hir_id, tcx) { + ret_to_ref_spans.entry(expr.span).or_default().push((name, span)); + continue; + } + + ref_to_spans.entry(name.clone()).or_default().push(span); + } + } + } + for call in visitor.calls { let Call { hir_id, @@ -191,6 +231,22 @@ fn transform( let func = some_or!(funcs.get(&callee), continue); call_spans.insert(span); + let mut args = args.clone(); + + for arg in args.iter_mut() { + for (name, span) in ref_and_call_spans.iter() { + if arg.span.contains(*span) { + let pre_span = arg.span.with_hi(span.lo()); + let post_span = arg.span.with_lo(span.hi()); + + let pre_s = source_map.span_to_snippet(pre_span).unwrap(); + let post_s = source_map.span_to_snippet(post_span).unwrap(); + + arg.code = format!("{}{}___v{}", pre_s, name, post_s); + } + } + } + for index in func.index_map.keys() { let span = to_comma(args[*index].span, source_map); fix(span, "".to_string()); @@ -321,20 +377,37 @@ fn transform( .params() .map(|param| { if param.must { - format!( - " + if passes.contains(¶m.name) { + format!( + " let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]); \ let mut {0}: *mut {1} = &mut {0}___v;", - param.name, param.ty, - ) + param.name, param.ty, + ) + } else { + format!( + " + let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]);", + param.name, param.ty, + ) + } } else { - format!( - " + if passes.contains(¶m.name) { + format!( + " let mut {0}___s: bool = false; \ let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]); \ let mut {0}: *mut {1} = &mut {0}___v;", - param.name, param.ty, - ) + param.name, param.ty, + ) + } else { + format!( + " + let mut {0}___s: bool = false; \ + let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]);", + param.name, param.ty, + ) + } } }) .collect(); @@ -355,9 +428,48 @@ fn transform( } } + for param in func.params() { + if let Some(spans) = ref_to_spans.get(¶m.name) { + for span in spans { + let assign = format!("{}___v", param.name); + fix(span.clone(), assign); + } + } + } + + for (span, ss) in ret_to_ref_spans.iter() { + let mut post_spans = vec![]; + + let mut sorted_ss = ss.clone(); + sorted_ss.sort_by_key(|x| x.1); + + let s = sorted_ss[0].1; + + let pre_span = span.with_hi(s.lo()).with_lo(span.lo() + BytePos(6)); + post_spans.push(span.with_lo(s.hi())); + + for (i, (_, s)) in sorted_ss[1..].iter().enumerate() { + post_spans[i] = post_spans[i].with_hi(s.lo()); + post_spans.push(span.with_lo(s.hi())); + } + + let pre_s = source_map.span_to_snippet(pre_span).unwrap(); + let post_s = source_map.span_to_snippet(post_spans[0]).unwrap(); + + let mut rv = format!("{}{}___v{}", pre_s, sorted_ss[0].0, post_s); + + for (i, s) in post_spans[1..].iter().enumerate() { + let post_s = source_map.span_to_snippet(s.clone()).unwrap(); + rv = format!("{}{}___v{}", rv, sorted_ss[i + 1].0, post_s); + } + let rv = func.return_value(Some(rv)); + + fix(span.clone(), format!("return {}", rv)); + } + for ret in visitor.returns { let Return { span, value } = ret; - if ret_call_spans.contains(&span) { + if ret_call_spans.contains(&span) || ret_to_ref_spans.contains_key(&span) { continue; } let orig = value.map(|value| source_map.span_to_snippet(value).unwrap()); @@ -586,16 +698,66 @@ struct Call { args: Vec, } -#[derive(Debug)] +#[derive(Debug, Clone)] struct Arg { span: Span, code: String, } +#[derive(Debug)] +struct Ref { + hir_id: HirId, + span: Span, + name: String, +} + +struct PassVisitor<'tcx> { + tcx: TyCtxt<'tcx>, + passes: Vec, +} + +impl<'tcx> PassVisitor<'tcx> { + fn new(tcx: TyCtxt<'tcx>) -> Self { + Self { + tcx, + passes: vec![], + } + } +} + +impl<'tcx> HVisitor<'tcx> for PassVisitor<'tcx> { + type NestedFilter = nested_filter::OnlyBodies; + + fn nested_visit_map(&mut self) -> Self::Map { + self.tcx.hir() + } + + fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) { + let source_map = self.tcx.sess.source_map(); + match expr.kind { + ExprKind::Path(p) => match source_map.span_to_snippet(p.qself_span()) { + Ok(code) => match get_parent(expr.hir_id, self.tcx) { + Some(e) => { + if let ExprKind::Unary(UnOp::Deref, _) = e.kind { + } else { + self.passes.push(code) + } + } + None => self.passes.push(code), + }, + Err(_) => {} + }, + _ => {} + } + rustc_hir::intravisit::walk_expr(self, expr); + } +} + struct BodyVisitor<'tcx> { tcx: TyCtxt<'tcx>, returns: Vec, calls: Vec, + refs: Vec, } impl<'tcx> BodyVisitor<'tcx> { @@ -604,6 +766,7 @@ impl<'tcx> BodyVisitor<'tcx> { tcx, returns: vec![], calls: vec![], + refs: vec![], } } } @@ -660,7 +823,16 @@ impl<'tcx> HVisitor<'tcx> for BodyVisitor<'tcx> { } fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) { + let source_map = self.tcx.sess.source_map(); match expr.kind { + ExprKind::Unary(unary, e) => match unary { + UnOp::Deref => self.refs.push(Ref { + hir_id: expr.hir_id, + span: expr.span, + name: source_map.span_to_snippet(e.span).unwrap(), + }), + _ => {} + }, ExprKind::Ret(e) => self.visit_expr_ret(expr, e), ExprKind::Call(callee, args) => self.visit_expr_call(expr, callee, args), _ => {} @@ -777,6 +949,15 @@ fn get_parent_return(hir_id: HirId, tcx: TyCtxt<'_>) -> Option<&Expr<'_>> { } } +fn get_parent_call(hir_id: HirId, tcx: TyCtxt<'_>) -> Option<&Expr<'_>> { + let parent = get_parent(hir_id, tcx)?; + if let ExprKind::Call(_, _) = parent.kind { + Some(parent) + } else { + get_parent_call(parent.hir_id, tcx) + } +} + #[derive(Debug)] struct IfCmpCall { if_span: Span, From 74903e68664ce272a6a5e73c2423a888d4ce8237 Mon Sep 17 00:00:00 2001 From: HoseongLee Date: Wed, 31 Jul 2024 14:01:05 +0900 Subject: [PATCH 2/2] formatting --- src/transform.rs | 66 +++++++++++++++++++++++------------------------- 1 file changed, 31 insertions(+), 35 deletions(-) diff --git a/src/transform.rs b/src/transform.rs index db66f4a..d9394e4 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -212,7 +212,10 @@ fn transform( } if let Some(expr) = get_parent_return(hir_id, tcx) { - ret_to_ref_spans.entry(expr.span).or_default().push((name, span)); + ret_to_ref_spans + .entry(expr.span) + .or_default() + .push((name, span)); continue; } @@ -391,23 +394,21 @@ fn transform( param.name, param.ty, ) } + } else if passes.contains(¶m.name) { + format!( + " + let mut {0}___s: bool = false; \ + let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]); \ + let mut {0}: *mut {1} = &mut {0}___v;", + param.name, param.ty, + ) } else { - if passes.contains(¶m.name) { - format!( - " - let mut {0}___s: bool = false; \ - let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]); \ - let mut {0}: *mut {1} = &mut {0}___v;", - param.name, param.ty, - ) - } else { - format!( - " - let mut {0}___s: bool = false; \ - let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]);", - param.name, param.ty, - ) - } + format!( + " + let mut {0}___s: bool = false; \ + let mut {0}___v: {1} = std::mem::transmute([0u8; std::mem::size_of::<{1}>()]);", + param.name, param.ty, + ) } }) .collect(); @@ -432,7 +433,7 @@ fn transform( if let Some(spans) = ref_to_spans.get(¶m.name) { for span in spans { let assign = format!("{}___v", param.name); - fix(span.clone(), assign); + fix(*span, assign); } } } @@ -459,12 +460,12 @@ fn transform( let mut rv = format!("{}{}___v{}", pre_s, sorted_ss[0].0, post_s); for (i, s) in post_spans[1..].iter().enumerate() { - let post_s = source_map.span_to_snippet(s.clone()).unwrap(); + let post_s = source_map.span_to_snippet(*s).unwrap(); rv = format!("{}{}___v{}", rv, sorted_ss[i + 1].0, post_s); } let rv = func.return_value(Some(rv)); - fix(span.clone(), format!("return {}", rv)); + fix(*span, format!("return {}", rv)); } for ret in visitor.returns { @@ -734,9 +735,9 @@ impl<'tcx> HVisitor<'tcx> for PassVisitor<'tcx> { fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) { let source_map = self.tcx.sess.source_map(); - match expr.kind { - ExprKind::Path(p) => match source_map.span_to_snippet(p.qself_span()) { - Ok(code) => match get_parent(expr.hir_id, self.tcx) { + if let ExprKind::Path(p) = expr.kind { + if let Ok(code) = source_map.span_to_snippet(p.qself_span()) { + match get_parent(expr.hir_id, self.tcx) { Some(e) => { if let ExprKind::Unary(UnOp::Deref, _) = e.kind { } else { @@ -744,10 +745,8 @@ impl<'tcx> HVisitor<'tcx> for PassVisitor<'tcx> { } } None => self.passes.push(code), - }, - Err(_) => {} - }, - _ => {} + } + } } rustc_hir::intravisit::walk_expr(self, expr); } @@ -825,14 +824,11 @@ impl<'tcx> HVisitor<'tcx> for BodyVisitor<'tcx> { fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) { let source_map = self.tcx.sess.source_map(); match expr.kind { - ExprKind::Unary(unary, e) => match unary { - UnOp::Deref => self.refs.push(Ref { - hir_id: expr.hir_id, - span: expr.span, - name: source_map.span_to_snippet(e.span).unwrap(), - }), - _ => {} - }, + ExprKind::Unary(UnOp::Deref, e) => self.refs.push(Ref { + hir_id: expr.hir_id, + span: expr.span, + name: source_map.span_to_snippet(e.span).unwrap(), + }), ExprKind::Ret(e) => self.visit_expr_ret(expr, e), ExprKind::Call(callee, args) => self.visit_expr_call(expr, callee, args), _ => {}