From 69456a99b398e46991165d413a8d344c10f86772 Mon Sep 17 00:00:00 2001 From: Arthur Chan Date: Wed, 20 Nov 2024 18:10:49 +0000 Subject: [PATCH] [Rust]: Fix calltree discovery logic (#1826) [Rust] Fix calltree discovery logic Signed-off-by: Arthur Chan --- .../rust_function_analyser/src/analyse.rs | 401 ++++++++++++++---- .../rust_function_analyser/src/call_tree.rs | 162 +++++-- 2 files changed, 439 insertions(+), 124 deletions(-) diff --git a/frontends/rust/rust_function_analyser/src/analyse.rs b/frontends/rust/rust_function_analyser/src/analyse.rs index 3efd9144a..1006a87c7 100644 --- a/frontends/rust/rust_function_analyser/src/analyse.rs +++ b/frontends/rust/rust_function_analyser/src/analyse.rs @@ -16,8 +16,10 @@ use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::fs; -use syn::{Expr, ExprBlock, FnArg, ImplItem, ImplItemFn, Item, ItemFn, ItemImpl, ReturnType, Stmt, Visibility}; -use syn::spanned::Spanned; +use syn::{ + punctuated::Punctuated, spanned::Spanned, Expr, ExprBlock, FnArg, ImplItemFn, Item, + ItemFn, Pat, ReturnType, Stmt, Visibility +}; // Base struct for BranchSide array in Branch Profile #[derive(Serialize, Deserialize, Debug, Clone)] @@ -95,7 +97,9 @@ pub struct FunctionAnalyser { pub functions: Vec, pub call_stack: HashMap>, pub reverse_call_map: HashMap, - pub method_impls: HashMap, + pub method_return_types: HashMap<(String, String), String>, + pub variable_types: HashMap, + pub first_pass_complete: bool, } // Major implementation for the AST visiting and analysing through the syn crate @@ -105,7 +109,9 @@ impl FunctionAnalyser { functions: Vec::new(), call_stack: HashMap::new(), reverse_call_map: HashMap::new(), - method_impls: HashMap::new(), + method_return_types: HashMap::new(), + variable_types: HashMap::new(), + first_pass_complete: false, } } @@ -116,16 +122,37 @@ impl FunctionAnalyser { let syntax = syn::parse_file(&file_content) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; - // Analyse and retrieve a list of functions/methods with all properties - for item in syntax.items { + // Analyse and retrieve a list of functions/methods return value and impl for processing + self.first_pass_complete = false; + for item in &syntax.items { match item { - Item::Fn(func) => self.visit_function(&func, file_path), - Item::Impl(ItemImpl { self_ty, items, .. }) => { - let parent_name = - format!("{}", quote::ToTokens::to_token_stream(&self_ty)); - for impl_item in items { - if let ImplItem::Fn(method) = impl_item { - self.visit_method(&method, file_path, &parent_name); + syn::Item::Fn(item_fn) => self.visit_function(item_fn, file_path), + syn::Item::Impl(item_impl) => { + if let syn::Type::Path(type_path) = &*item_impl.self_ty { + let impl_type = type_path.path.segments.last().unwrap().ident.to_string(); + for item in &item_impl.items { + if let syn::ImplItem::Fn(method) = item { + self.visit_method(method, file_path, &impl_type); + } + } + } + } + _ => {} + } + } + + // Second pass to handle functions/methods call and process them directly + self.first_pass_complete = true; + for item in &syntax.items { + match item { + syn::Item::Fn(item_fn) => self.visit_function(item_fn, file_path), + syn::Item::Impl(item_impl) => { + if let syn::Type::Path(type_path) = &*item_impl.self_ty { + let impl_type = type_path.path.segments.last().unwrap().ident.to_string(); + for item in &item_impl.items { + if let syn::ImplItem::Fn(method) = item { + self.visit_method(method, file_path, &impl_type); + } } } } @@ -136,39 +163,71 @@ impl FunctionAnalyser { Ok(()) } - // visit implementation to go through all functions from the AST + // visit implementation to go through all functions from the AST in two passes approach pub fn visit_function(&mut self, node: &ItemFn, file: &str) { - let visibility = self.get_visibility(&node.vis); - let (start_line, end_line) = self.get_function_lines(&node.block.brace_token); - self.process_function( - &node.sig.ident.to_string(), - &node.sig.inputs, - &node.sig.output, - &node.block.stmts, - file, - visibility, - start_line, - end_line, - ); + self.extract_parameter_types(&node.sig.inputs); + + if !self.first_pass_complete { + if let syn::ReturnType::Type(_, ty) = &node.sig.output { + if let syn::Type::Path(type_path) = &**ty { + let function_name = node.sig.ident.to_string(); + let return_type = type_path.path.segments.last().unwrap().ident.to_string(); + self.method_return_types.insert(("".to_string(), function_name), return_type); + } + } + } else { + let visibility = self.get_visibility(&node.vis); + let (start_line, end_line) = self.get_function_lines(&node.block.brace_token); + self.process_function( + &node.sig.ident.to_string(), + &node.sig.inputs, + &node.sig.output, + &node.block.stmts, + file, + visibility, + start_line, + end_line, + ); + self.variable_types.clear(); + } } // visit implementation to go through all methods from the AST pub fn visit_method(&mut self, node: &ImplItemFn, file: &str, parent_name: &str) { - let name = format!("{}::{}", parent_name, node.sig.ident); - self.method_impls - .insert(node.sig.ident.to_string(), parent_name.to_string()); - let visibility = self.get_visibility(&node.vis); - let (start_line, end_line) = self.get_function_lines(&node.block.brace_token); - self.process_function( - &name, - &node.sig.inputs, - &node.sig.output, - &node.block.stmts, - file, - visibility, - start_line, - end_line, - ); + let method_name = format!("{}::{}", parent_name, node.sig.ident); + + if !self.first_pass_complete { + let return_type = match &node.sig.output { + syn::ReturnType::Type(_, ty) => match &**ty { + syn::Type::Path(type_path) => type_path + .path + .segments + .last() + .map(|seg| seg.ident.to_string()), + _ => None, + }, + syn::ReturnType::Default => None, + }; + + if let Some(return_type) = return_type { + self.method_return_types + .insert((parent_name.to_string(), method_name), return_type); + } + } else { + self.extract_parameter_types(&node.sig.inputs); + let visibility = self.get_visibility(&node.vis); + let (start_line, end_line) = self.get_function_lines(&node.block.brace_token); + self.process_function( + &method_name, + &node.sig.inputs, + &node.sig.output, + &node.block.stmts, + file, + visibility, + start_line, + end_line, + ); + } } // Internal method to process each functions/methods when going through them in the AST @@ -184,25 +243,58 @@ impl FunctionAnalyser { start_line: usize, end_line: usize, ) { + // Clean function/method name + let cleaned_name = self.clean_function_name(name.to_string()); + // Discover return type of the target function/method let return_type = match output { ReturnType::Default => "void".to_string(), - ReturnType::Type(_, ty) => format!("{}", quote::ToTokens::to_token_stream(&**ty)), - } - .replace(' ', ""); + ReturnType::Type(_, ty) => { + let mut return_type = self.clean_function_name(format!("{}", quote::ToTokens::to_token_stream(&**ty))); + if cleaned_name.contains("::") && return_type == "Self" { + if let Some(pos) = name.rfind("::") { + return_type = name[..pos].to_string(); + } + } + return_type + } + }; // Discover the arg types Vector of the target function/method let arg_types = inputs .iter() .filter_map(|arg| { if let FnArg::Typed(pat) = arg { - Some(format!("{}", quote::ToTokens::to_token_stream(&*pat.ty)).replace(' ', "")) + Some(format!("{}", self.clean_function_name(quote::ToTokens::to_token_stream(&*pat.ty).to_string()))) } else { None } }) .collect::>(); + // Discover the arg names Vector of the target function/method + let arg_names = inputs + .iter() + .filter_map(|arg| { + if let FnArg::Typed(pat) = arg { + if let Pat::Ident(ident) = &*pat.pat { + Some(ident.ident.to_string()) + } else { + None + } + } else { + None + } + }) + .collect::>(); + + // Mapping of argument name and type + let arg_map: HashMap = arg_names + .clone() + .into_iter() + .zip(arg_types.clone().into_iter()) + .collect(); + // Calculate the cyclomatic complexity of the target function/method let complexity = self.calculate_cyclomatic_complexity(stmts); @@ -220,14 +312,14 @@ impl FunctionAnalyser { // Generate branch profiles for the target function/method. The SYN create AST // approach currently only support branching analysis for if statement. - let branch_profiles = self.profile_branches(stmts, file); + let branch_profiles = self.profile_branches(stmts, file, &arg_map); // Extract the callsites and called functions information from the target function/method let mut called_functions = Vec::new(); let mut callsites = Vec::new(); for stmt in stmts { - self.extract_called_functions(stmt, &mut called_functions, &mut callsites, file); + self.extract_called_functions(stmt, &mut called_functions, &mut callsites, file, &arg_map); } called_functions.sort(); @@ -241,8 +333,8 @@ impl FunctionAnalyser { self.functions.push(FunctionInfo { linkage_type: String::new(), constants_touched: Vec::new(), - arg_names: Vec::new(), - name: name.to_string(), + arg_names, + name: cleaned_name, file: file.to_string(), return_type, arg_count: arg_types.len(), @@ -274,9 +366,33 @@ impl FunctionAnalyser { called_functions: &mut Vec, callsites: &mut Vec, file: &str, + arg_map: &HashMap, ) { - if let Stmt::Expr(expr, _) = stmt { - self.extract_from_expr(expr, called_functions, callsites, file); + match stmt { + Stmt::Local(local_stmt) => { + if let Some(init_expr) = &local_stmt.init { + + self.extract_from_expr(&init_expr.expr, called_functions, callsites, file, arg_map); + } + } + + Stmt::Item(item) => { + if let Item::Fn(item_fn) = item { + for stmt in &item_fn.block.stmts { + self.extract_called_functions(stmt, called_functions, callsites, file, arg_map); + } + } + } + + Stmt::Expr(expr, _) => { + self.extract_from_expr(expr, called_functions, callsites, file, arg_map); + } + + Stmt::Macro(macro_stmt) => { + if let Ok(parsed_body) = macro_stmt.mac.parse_body::() { + self.extract_from_expr(&parsed_body, called_functions, callsites, file, arg_map); + } + } } } @@ -291,6 +407,7 @@ impl FunctionAnalyser { called_functions: &mut Vec, callsites: &mut Vec, file: &str, + arg_map: &HashMap, ) { match expr { // General function call @@ -304,17 +421,17 @@ impl FunctionAnalyser { .map(|seg| seg.ident.to_string()) .collect::>() .join("::"); - called_functions.push(full_path.clone()); + called_functions.push(self.clean_function_name(full_path.clone())); let span = call_expr.func.span().start(); callsites.push(CallSite { src: format!("{},{},{}", file, span.line, span.column), - dst: full_path, + dst: self.clean_function_name(full_path), }); } // Handle method/function in arguments for arg in &call_expr.args { - self.extract_from_expr(arg, called_functions, callsites, file); + self.extract_from_expr(arg, called_functions, callsites, file, arg_map); } } @@ -326,23 +443,43 @@ impl FunctionAnalyser { called_functions, callsites, file, + arg_map, ); - // Handle method call + // Determine correct method impl + let receiver_type = self.extract_receiver_type(&method_call.receiver); let method_name = method_call.method.to_string(); - if let Some(impl_name) = self.method_impls.get(&method_name) { - let full_path = format!("{}::{}", impl_name, method_name); - called_functions.push(full_path.clone()); - let span = method_call.span().start(); - callsites.push(CallSite { - src: format!("{},{},{}", file, span.line, span.column), - dst: full_path, - }); - } + let resolved_type = match receiver_type.as_deref() { + Some(typ) => Some(typ.to_string()), + None => { + if let Expr::Path(path) = &*method_call.receiver { + if let Some(ident) = path.path.get_ident() { + arg_map.get(&ident.to_string()).cloned() + } else { + None + } + } else { + None + } + } + }; + + let full_path = match resolved_type { + Some(receiver) => format!("{}::{}", receiver, method_name), + None => method_name.clone(), + }; + + // Store called functions/methods + called_functions.push(self.clean_function_name(full_path.clone())); + let span = method_call.span().start(); + callsites.push(CallSite { + src: format!("{},{},{}", file, span.line, span.column), + dst: self.clean_function_name(full_path), + }); // Handle method/function in arguments for arg in &method_call.args { - self.extract_from_expr(arg, called_functions, callsites, file); + self.extract_from_expr(arg, called_functions, callsites, file, arg_map); } } @@ -352,12 +489,12 @@ impl FunctionAnalyser { match stmt { Stmt::Local(local_stmt) => { if let Some(init_expr) = &local_stmt.init { - self.extract_from_expr(&init_expr.expr, called_functions, callsites, file); + self.extract_from_expr(&init_expr.expr, called_functions, callsites, file, arg_map); } } Stmt::Expr(expr, _) => { - self.extract_from_expr(expr, called_functions, callsites, file); + self.extract_from_expr(expr, called_functions, callsites, file, arg_map); } _ => {} @@ -376,6 +513,7 @@ impl FunctionAnalyser { called_functions, callsites, file, + arg_map, ); if let Some((_, else_expr)) = &if_expr.else_branch { match else_expr.as_ref() { @@ -386,6 +524,7 @@ impl FunctionAnalyser { called_functions, callsites, file, + arg_map, ); } @@ -396,6 +535,7 @@ impl FunctionAnalyser { called_functions, callsites, file, + arg_map, ); } @@ -413,18 +553,19 @@ impl FunctionAnalyser { called_functions, callsites, file, + arg_map, ); } } // Await statement Expr::Await(await_expr) => { - self.extract_from_expr(&await_expr.base, called_functions, callsites, file); + self.extract_from_expr(&await_expr.base, called_functions, callsites, file, arg_map); } // Try statment Expr::Try(try_expr) => { - self.extract_from_expr(&try_expr.expr, called_functions, callsites, file); + self.extract_from_expr(&try_expr.expr, called_functions, callsites, file, arg_map); } // While loop @@ -438,6 +579,7 @@ impl FunctionAnalyser { called_functions, callsites, file, + arg_map, ); } @@ -452,6 +594,7 @@ impl FunctionAnalyser { called_functions, callsites, file, + arg_map, ); } @@ -462,79 +605,124 @@ impl FunctionAnalyser { called_functions, callsites, file, + arg_map, ); } // Struct context Expr::Struct(struct_expr) => { for field in &struct_expr.fields { - self.extract_from_expr(&field.expr, called_functions, callsites, file); + self.extract_from_expr(&field.expr, called_functions, callsites, file, arg_map); } } // Indexing for vector and array Expr::Index(index_expr) => { - self.extract_from_expr(&index_expr.expr, called_functions, callsites, file); - self.extract_from_expr(&index_expr.index, called_functions, callsites, file); + self.extract_from_expr(&index_expr.expr, called_functions, callsites, file, arg_map); + self.extract_from_expr(&index_expr.index, called_functions, callsites, file, arg_map); } // Impl field accessing Expr::Field(field_expr) => { - self.extract_from_expr(&field_expr.base, called_functions, callsites, file); + self.extract_from_expr(&field_expr.base, called_functions, callsites, file, arg_map); } // Tuple handling Expr::Tuple(tuple_expr) => { for elem in &tuple_expr.elems { - self.extract_from_expr(elem, called_functions, callsites, file); + self.extract_from_expr(elem, called_functions, callsites, file, arg_map); } } // Macro invocations Expr::Macro(macro_expr) => { if let Ok(parsed_body) = macro_expr.mac.parse_body::() { - self.extract_from_expr(&parsed_body, called_functions, callsites, file); + self.extract_from_expr(&parsed_body, called_functions, callsites, file, arg_map); } } // Return statement Expr::Return(return_expr) => { if let Some(expr) = &return_expr.expr { - self.extract_from_expr(expr, called_functions, callsites, file); + self.extract_from_expr(expr, called_functions, callsites, file, arg_map); } } // Assigning statement Expr::Assign(assign_expr) => { - self.extract_from_expr(&assign_expr.left, called_functions, callsites, file); - self.extract_from_expr(&assign_expr.right, called_functions, callsites, file); + self.extract_from_expr(&assign_expr.left, called_functions, callsites, file, arg_map); + self.extract_from_expr(&assign_expr.right, called_functions, callsites, file, arg_map); } // Binary comparison Expr::Binary(binary_expr) => { - self.extract_from_expr(&binary_expr.left, called_functions, callsites, file); - self.extract_from_expr(&binary_expr.right, called_functions, callsites, file); + self.extract_from_expr(&binary_expr.left, called_functions, callsites, file, arg_map); + self.extract_from_expr(&binary_expr.right, called_functions, callsites, file, arg_map); } // Unary Comparison Expr::Unary(unary_expr) => { - self.extract_from_expr(&unary_expr.expr, called_functions, callsites, file); + self.extract_from_expr(&unary_expr.expr, called_functions, callsites, file, arg_map); } // Paren Statement Expr::Paren(paren_expr) => { - self.extract_from_expr(&paren_expr.expr, called_functions, callsites, file); + self.extract_from_expr(&paren_expr.expr, called_functions, callsites, file, arg_map); } // Grouping process Expr::Group(group_expr) => { - self.extract_from_expr(&group_expr.expr, called_functions, callsites, file); + self.extract_from_expr(&group_expr.expr, called_functions, callsites, file, arg_map); } _ => {} } } + // Helper method to determine correcct parameter type for method call impl discovery + fn extract_parameter_types(&mut self, inputs: &Punctuated) { + for input in inputs { + if let syn::FnArg::Typed(pat_type) = input { + if let syn::Pat::Ident(pat_ident) = &*pat_type.pat { + let variable_name = pat_ident.ident.to_string(); + let variable_type = match &*pat_type.ty { + syn::Type::Path(type_path) => type_path + .path + .segments + .last() + .map(|seg| seg.ident.to_string()), + _ => None, + }; + if let Some(var_type) = variable_type { + self.variable_types.insert(variable_name, var_type); + } + } + } + } + } + + // Helper method to determine correct receiver type of a method call + fn extract_receiver_type(&self, receiver: &syn::Expr) -> Option { + match receiver { + // For variable or parameter calls + Expr::Path(path_expr) => { + let variable_name = path_expr.path.segments.last()?.ident.to_string(); + self.variable_types.get(&variable_name).cloned() + } + + // For chained calls + Expr::MethodCall(method_call) => { + let receiver_type = self.extract_receiver_type(&method_call.receiver)?; + let method_name = method_call.method.to_string(); + self.method_return_types + .get(&(receiver_type, method_name)) + .cloned() + } + + _ => None, + } + } + // Transform Visibility enum of rust functions/methods into string fn get_visibility(&self, vis: &Visibility) -> String { match vis { @@ -560,7 +748,7 @@ impl FunctionAnalyser { // Internal helper method for extracing branch profile of a function // Currently, the SYN crate AST approach only support branching with IF statement // TODO Find other ways to extract and handle of other branching statements - fn profile_branches(&self, stmts: &[Stmt], file: &str) -> Vec { + fn profile_branches(&self, stmts: &[Stmt], file: &str, arg_map: &HashMap) -> Vec { let mut branch_profiles = Vec::new(); for stmt in stmts { @@ -573,11 +761,11 @@ impl FunctionAnalyser { if_expr.if_token.span.start().column ); - let mut branch_sides = vec![self.extract_branch_side(&if_expr.then_branch, file)]; + let mut branch_sides = vec![self.extract_branch_side(&if_expr.then_branch, file, arg_map)]; if let Some((_, else_block)) = &if_expr.else_branch { if let Expr::Block(block_expr) = &**else_block { - branch_sides.push(self.extract_branch_side(&block_expr.block, file)); + branch_sides.push(self.extract_branch_side(&block_expr.block, file, arg_map)); } } @@ -594,10 +782,10 @@ impl FunctionAnalyser { } // Internal helper for profile_branches to retrieve information of the branch side for the if statement - fn extract_branch_side(&self, block: &syn::Block, file: &str) -> BranchSide { + fn extract_branch_side(&self, block: &syn::Block, file: &str, arg_map: &HashMap) -> BranchSide { let mut branch_side_funcs = vec![]; for stmt in &block.stmts { - self.extract_called_functions(stmt, &mut branch_side_funcs, &mut vec![], file); + self.extract_called_functions(stmt, &mut branch_side_funcs, &mut vec![], file, arg_map); } let span = block.brace_token.span.open().start(); @@ -629,12 +817,49 @@ impl FunctionAnalyser { called.clone() }) .collect(); + + for callsite in &mut func.callsites { + if !callsite.dst.contains("::") { + let qualified_name = format!("{}::{}", impl_name, callsite.dst); + if function_set.contains(&qualified_name) { + callsite.dst = qualified_name; + } + } + } } func.function_uses = *self.reverse_call_map.get(&func.name).unwrap_or(&0); } } + // Internal helper method to clean function name + fn clean_function_name(&self, input: String) -> String { + let mut result = String::new(); + let mut inside_angle_brackets = false; + + for c in input.chars() { + match c { + '<' => inside_angle_brackets = true, + '>' => inside_angle_brackets = false, + '\'' => continue, + _ if inside_angle_brackets || c.is_whitespace() => continue, + _ => result.push(c), + } + } + + + // Trim unncessary prefix + let trimmed_result = if result.starts_with("&mut") { + result[4..].to_string() + } else if input.starts_with('&') { + result[1..].to_string() + } else { + result + }; + + trimmed_result + } + // Internal entry method for calculating function depth recursively fn calculate_function_depth(&self, name: &str) -> usize { let mut visited = HashSet::new(); diff --git a/frontends/rust/rust_function_analyser/src/call_tree.rs b/frontends/rust/rust_function_analyser/src/call_tree.rs index 61fb1e171..58eb7dc32 100644 --- a/frontends/rust/rust_function_analyser/src/call_tree.rs +++ b/frontends/rust/rust_function_analyser/src/call_tree.rs @@ -50,7 +50,7 @@ pub fn generate_call_trees( writeln!(output, "fuzz_target {} linenumber=-1", fuzz_file)?; // Extract functions from the fuzz_target macro in the harness - let called_functions = extract_called_functions(fuzz_file)?; + let called_functions = extract_called_functions(fuzz_file, functions)?; // Build the call tree let mut visited = HashSet::new(); @@ -120,24 +120,103 @@ fn find_fuzzing_harnesses(dir: &str) -> io::Result> { } // Extract all functions in the fuzz_target macro in the fuzzing harnesses -fn extract_called_functions(file_path: &str) -> io::Result> { +fn extract_called_functions( + file_path: &str, + function_info: &[FunctionInfo], +) -> io::Result> { let content = fs::read_to_string(file_path)?; let syntax = syn::parse_file(&content).expect("Failed to parse file"); - let mut visitor = FuzzTargetVisitor::default(); + let mut visitor = FuzzTargetVisitor::new(function_info.to_vec()); visitor.visit_file(&syntax); - // Remove duplicate items + // Remove duplicate items and sort by line number let set: HashSet<_> = visitor.called_functions.into_iter().collect(); - let result = set.into_iter().collect(); + let mut result: Vec<(String, usize)> = set.into_iter().collect(); + result.sort_by_key(|item| item.1); + result = post_process_called_functions(result); Ok(result) } +// Helper function to post process the called function vector +fn post_process_called_functions(items: Vec<(String, usize)>) -> Vec<(String, usize)> { + let mut stored_value: Option = None; + let mut result = Vec::new(); + + for (mut string_value, usize_value) in items { + if let Some(pos) = string_value.rfind("::") { + stored_value = Some(string_value[..pos].to_string()); + } else if let Some(stored) = &stored_value { + string_value = format!("{}::{}", stored, string_value); + } + + // Push the updated item into the result + result.push((string_value, usize_value)); + } + + result +} + // Base struct and syn:Visit implementation for traversing the function call tree #[derive(Default)] struct FuzzTargetVisitor { called_functions: Vec<(String, usize)>, + function_info: Vec, + variable_types: HashMap, +} + +impl FuzzTargetVisitor { + pub fn new(function_info: Vec) -> Self { + FuzzTargetVisitor { + called_functions: Vec::new(), + function_info, + variable_types: HashMap::new(), + } + } + + // Helper method to extract type of method call receiver + fn extract_receiver_type(&self, receiver: &Expr) -> Option { + match receiver { + // Variable or parameter call + Expr::Path(path_expr) => { + let variable_name = path_expr.path.segments.last()?.ident.to_string(); + self.variable_types.get(&variable_name).cloned() + } + + // Chained method call + Expr::MethodCall(method_call) => { + let receiver_type = self.extract_receiver_type(&method_call.receiver); + let method_name = method_call.method.to_string(); + let name = match receiver_type { + Some(receiver) => format!("{}::{}", receiver, method_name), + None => method_name.clone(), + }; + self.lookup_function_return_type(&name) + } + + _ => None, + } + } + + // Helper method to lookup function return type for reference + fn lookup_function_return_type(&self, method_name: &str) -> Option { + let function_map: HashMap = self.function_info.iter().map(|f| (f.name.clone(), f)).collect(); + + if let Some(function_info) = find_function(method_name, &function_map) { + return Some(function_info.return_type.clone()); + } + None + } + + // Try extracting the local variable name creation + fn extract_variable_name(&self, pat: &syn::Pat) -> Option { + if let syn::Pat::Ident(ident) = pat { + Some(ident.ident.to_string()) + } else { + None + } + } } impl<'ast> Visit<'ast> for FuzzTargetVisitor { @@ -161,8 +240,6 @@ impl<'ast> Visit<'ast> for FuzzTargetVisitor { for arg in &node.args { self.visit_expr(arg); } - - syn::visit::visit_expr_call(self, node); } // visit implementation method for handling echo method experssion @@ -171,21 +248,30 @@ impl<'ast> Visit<'ast> for FuzzTargetVisitor { let span = node.method.span().start(); let line_number = span.line; - if let Expr::Path(ExprPath { path, .. }) = &*node.receiver { - let receiver_name = path_to_string(&path); - let qualified_name = format!("{}::{}", receiver_name, method_name); - self.called_functions.push((qualified_name, line_number)); - } else { - let qualified_name = method_name; - self.called_functions.push((qualified_name, line_number)); - } + // Determine the fully qualified name + let receiver_type = self.extract_receiver_type(&node.receiver); + let qualified_name = match receiver_type { + Some(receiver) => format!("{}::{}", receiver, method_name), + None => method_name.clone(), + }; + + self.called_functions.push((qualified_name, line_number)); self.visit_expr(&node.receiver); for arg in &node.args { self.visit_expr(arg); } + } - syn::visit::visit_expr_method_call(self, node); + // visit implementation for local variables + fn visit_local(&mut self, local: &syn::Local) { + if let Some(init_expr) = &local.init { + if let Some(var_name) = self.extract_variable_name(&local.pat) { + if let Some(var_type) = self.extract_receiver_type(&init_expr.expr) { + self.variable_types.insert(var_name, var_type); + } + } + } } // General method ensure visiting all kinds of Expr that could call functions/methods @@ -352,8 +438,12 @@ fn build_call_tree( depth: usize, ) -> Option { let mut result = String::new(); + let indent = " ".repeat(depth + 1); + + if line_number == 0 { + line_number = -1; + } - // Only include functions/methods found in the project (determined from analysis result) if let Some(function_info) = find_function(function_name, function_map) { if visited.contains(&function_info.name) { return None; @@ -361,12 +451,6 @@ fn build_call_tree( visited.insert(function_info.name.clone()); - let indent = " ".repeat(depth + 1); - - if line_number == 0 { - line_number = -1; - } - // Insert the call tree line result.push_str(&format!( "{}{} {} linenumber={}\n", @@ -392,8 +476,12 @@ fn build_call_tree( } } } + } else { + result.push_str(&format!( + "{}{} {} linenumber={}\n", + indent, function_name.replace(" ", ""), call_path, line_number + )); } - if result.is_empty() { None } else { @@ -404,25 +492,27 @@ fn build_call_tree( // Search for the functions in the analysis result and exclude functions/methods not from the project fn find_function<'a>( function_name: &str, - function_map: &'a HashMap, + function_map: &'a HashMap, ) -> Option<&'a FunctionInfo> { + // Exact match if let Some(func) = function_map.get(function_name) { return Some(func); } - let simplified_name = function_name.split("::").last().unwrap_or(function_name); - let mut best_match: Option<&FunctionInfo> = None; - let mut best_match_length = 0; + // Match any key that ends with function_name + if let Some((_, func)) = function_map.iter().find(|(key, _)| key.ends_with(function_name)) { + return Some(func); + } - for func in function_map.values() { - if func.name.ends_with(simplified_name) { - let match_length = func.name.len(); - if match_length > best_match_length { - best_match = Some(func); - best_match_length = match_length; - } + // Split and check segments from the right side + let segments: Vec<&str> = function_name.split("::").collect(); + for i in 0..segments.len() { + let partial_name = segments[i..].join("::"); + if let Some(func) = function_map.get(&partial_name) { + return Some(func); } } - best_match + // No match found + None }