Skip to content

Commit

Permalink
Only keep functions referenced by fptr if fptr is used
Browse files Browse the repository at this point in the history
Up until this commit, any function referenced by a fptr would be unconditionally excluded
from DCE. Now we're actually checking if the function pointer is even used. If it isn't, don't
bother excluding the referenced function!

To achieve this, the tool now treats calls and function pointer assignments equally. Instead of
resolve_calls resolving into calls[] and resolve_fptrs[] resolving into fptrs[], both now resolve
into referenced_functions[]. traverse_calls() has also been renamed to traverse_functions() since
it isn't restricted to calls anymore.

Signed-off-by: Patrick Pedersen <[email protected]>
  • Loading branch information
CTXz committed Jun 24, 2024
1 parent 5bef024 commit 8db03a6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 34 deletions.
31 changes: 11 additions & 20 deletions src/stm8dce/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def run(
for function in functions:
function.resolve_constants(constants)

# Resolve constants accessed by initializers
# Resolve functions and constants accessed by initializers
debug.pdbg()
debug.pdbg("Resolving functions and constants accessed by initializers")
debug.pseperator()
Expand Down Expand Up @@ -213,7 +213,7 @@ def run(
debug.pdbg()
debug.pdbg(f"Traversing entry function: {entry_label}")
debug.pseperator()
keep_functions += [entry_function] + asm_analysis.traverse_calls(
keep_functions += [entry_function] + asm_analysis.traverse_functions(
functions, entry_function
)
elif modules:
Expand Down Expand Up @@ -242,25 +242,12 @@ def run(
f"Traversing function {function.name} referenced by module {entry_module.name}"
)
debug.pseperator()
keep_functions += [function] + asm_analysis.traverse_calls(
keep_functions += [function] + asm_analysis.traverse_functions(
functions, function
)
else:
raise ValueError(f"Error: Entry label not found: {entry_label}")

# Keep functions assigned to a function pointer
for func in functions:
for function_pointer in func.fptrs:
if function_pointer not in keep_functions:
debug.pdbg()
debug.pdbg(
f"Traversing function assigned to function pointer: {function_pointer.name}"
)
debug.pseperator()
keep_functions += [function_pointer] + asm_analysis.traverse_calls(
functions, function_pointer
)

# Keep interrupt handlers and all of their traversed calls
# but exclude unused IRQ handlers if opted by the user
interrupt_handlers = asm_analysis.interrupt_handlers(functions)
Expand All @@ -270,7 +257,9 @@ def run(
debug.pdbg()
debug.pdbg(f"Traversing IRQ handler: {handler.name}")
debug.pseperator()
keep_functions += [handler] + asm_analysis.traverse_calls(functions, handler)
keep_functions += [handler] + asm_analysis.traverse_functions(
functions, handler
)

# Keep functions accessed by initializers
for initializer in initializers:
Expand All @@ -284,7 +273,7 @@ def run(
f"Traversing function {function_pointer.name} accessed by initializer"
)
debug.pseperator()
keep_functions += [function_pointer] + asm_analysis.traverse_calls(
keep_functions += [function_pointer] + asm_analysis.traverse_functions(
functions, function_pointer
)

Expand Down Expand Up @@ -313,7 +302,7 @@ def run(
debug.pdbg()
debug.pdbg(f"Traversing excluded function: {name}")
debug.pseperator()
keep_functions += [excluded_function] + asm_analysis.traverse_calls(
keep_functions += [excluded_function] + asm_analysis.traverse_functions(
functions, excluded_function
)

Expand All @@ -333,7 +322,9 @@ def run(
f"Traversing function {ref.name} referenced by module {module.name}"
)
debug.pseperator()
keep_functions += [ref] + asm_analysis.traverse_calls(functions, ref)
keep_functions += [ref] + asm_analysis.traverse_functions(
functions, ref
)
elif isinstance(ref, asm_analysis.Constant) and ref not in keep_constants:
keep_constants.append(ref)

Expand Down
27 changes: 13 additions & 14 deletions src/stm8dce/asm_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,11 @@ class Function:
long_read_labels_str (list): List of long read labels.
Generated Attributes:
calls (list): List of resolved functions called by the function (See resolve_calls).
function_references (list): List of functions referenced by the function (See resolve_calls & resolve_fptrs).
external_calls (list): List of external functions (in rel & lib files) called by the function.
constants (list): List of resolved constants read by the function (See resolve_constants).
external_constants (list): List of external constants (in rel & lib files) read by the function.
global_defs (list): List of resolved global definitions used by the function (See resolve_globals).
fptrs (list): List of resolved function pointers assigned by the function (See resolve_fptrs).
isr_def (IntDef): Resolved interrupt definition associated with the function (See resolve_isr).
empty (bool): Indicates if the function is empty.
Expand All @@ -117,12 +116,11 @@ def __init__(self, path, start_line_number, name):
self.calls_str = []
self.long_read_labels_str = []

self.calls = []
self.function_references = []
self.external_calls = []
self.constants = []
self.external_constants = []
self.global_defs = []
self.fptrs = []
self.isr_def = None
self.empty = True

Expand All @@ -140,14 +138,15 @@ def print(self):
print(f"End line: {self.end_line_number}")
print(f"Calls: {self.calls_str}")
print(f"Long read labels: {self.long_read_labels_str}")
print(f"Resolved calls: {[call.name for call in self.calls]}")
print(
f"Resolved function references: {[call.name for call in self.function_references]}"
)
print(f"External calls: {self.external_calls}")
print(f"Resolved constants: {[const.name for const in self.constants]}")
print(f"External constants: {self.external_constants}")
print(
f"Resolved global definitions: {[glob.name for glob in self.global_defs]}"
)
print(f"Resolved function pointers: {[fptr.name for fptr in self.fptrs]}")
print(f"IRQ Handler: {self.isr_def}")
print(f"Empty: {self.empty}")

Expand Down Expand Up @@ -209,7 +208,7 @@ def resolve_calls(self, functions):
for func in funcs:
print(f"In file {func.path}:{func.start_line_number}")
exit(1)
self.calls.append(funcs[0])
self.function_references.append(funcs[0])
debug.pdbg(
f"Function {self.name} in {self.path}:{self.start_line_number} calls function {funcs[0].name} in {funcs[0].path}:{funcs[0].start_line_number}"
)
Expand All @@ -222,7 +221,7 @@ def resolve_calls(self, functions):
f"Error: Multiple static definitions for function {func} in {func.path}"
)
exit(1)
self.calls.append(func)
self.function_references.append(func)
debug.pdbg(
f"Function {self.name} in {self.path}:{self.start_line_number} calls static function {func.name} in {func.path}:{func.start_line_number}"
)
Expand All @@ -238,7 +237,7 @@ def resolve_fptrs(self, functions):
for long_read_label in self.long_read_labels_str:
for func in functions:
if func.name == long_read_label:
self.fptrs.append(func)
self.function_references.append(func)
debug.pdbg(
f"Function {self.name} in {self.path}:{self.start_line_number} assigns function pointer to {func.name} in {func.path}:{func.start_line_number}"
)
Expand Down Expand Up @@ -568,9 +567,9 @@ def constant_by_filename_name(constants, filename, name):
return ret


def traverse_calls(functions, top):
def traverse_functions(functions, top):
"""
Traverse all calls made by a function and return a list of all traversed functions.
Traverse all functions referenced by a function and return a list of all traversed functions.
Args:
functions (list): List of Function objects.
Expand All @@ -583,11 +582,11 @@ def traverse_calls(functions, top):

ret = []

for call in top.calls:
if call == top:
for function in top.function_references:
if function == top:
continue

ret += [call] + traverse_calls(functions, call)
ret += [function] + traverse_functions(functions, function)

debug.pdbg(f"Traversing out {top.name} in {top.path}:{top.start_line_number}")

Expand Down

0 comments on commit 8db03a6

Please sign in to comment.