Skip to content

Commit

Permalink
Add some printing method (for debug)
Browse files Browse the repository at this point in the history
  • Loading branch information
yilin-bao authored Dec 9, 2023
1 parent 04a19a4 commit e001d96
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions analyzer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast, astor
import enum
from ctypes import Array
import time
import hashlib
Expand Down Expand Up @@ -292,7 +293,7 @@ def analyze_defined_module(self, var_name, module_name, module, var_module_layer
# [var_module_layer] is the variable-module dictionary of current layer
analyzer = ModuleAstAnalyzer(var_module_layer, var_name, module_name)
analyzer.visit(module_ast)
# print("Results:", analyzer.module_map)
self.print_module_map(analyzer.module_map)
return 0

# def analyze_inbuild_module(self, var_name, var_whole_name, module_name, module):
Expand Down Expand Up @@ -403,6 +404,13 @@ def print_current_layer_information(self, var_whole_name, module_name, depth=2):
# else:
# print("This layer is not a final deconstructed layer, so there is no weight and bias")

def print_module_map(self, module_map, length=8):
print("================================")
print("Analyzer returns the module_map:")
print("================================")
for i, mm in enumerate(module_map):
print(f'Here is step {i} of the layer:', [a[:length] for a in mm[0]], [a[:length] for a in mm[1]], mm[2])

#============================================================
#======Ast static analyzer finds what happen in forward======
#============================================================
Expand Down Expand Up @@ -603,13 +611,14 @@ def special_case_length_two(self, parent, grandparent, this:Name):
# self.hash_var_dict = {}
if isinstance(grandparent, ast.Assign) and isinstance(parent, ast.Call):
if grandparent.value == parent:
op = self.from_node_to_operation(parent.func)
op_name = self.from_node_to_operation(parent.func)
# print(self.var_module_dict)
args = self.find_full_name_array(parent.args)
targets = self.find_full_name_array(grandparent.targets)
print(args, targets)

self.print_parents_and_code(this)
self.update_module_name(args, targets, op_name)
else:
# self.print_parents_and_code(this)
pass
elif isinstance(grandparent, ast.Call) and isinstance(parent, ast.Call):
# self.print_parents_and_code(this)
pass
Expand Down

0 comments on commit e001d96

Please sign in to comment.