Skip to content

Commit

Permalink
Update analyzer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yilin-bao authored Dec 8, 2023
1 parent ba5a59b commit 09e7a7d
Showing 1 changed file with 81 additions and 218 deletions.
299 changes: 81 additions & 218 deletions analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,39 +200,6 @@ def remove_var_flag(self):
self.var_flag = '.'.join(arr[:-1])
else:
self.var_flag = None

# Getter method for Variable Module Dictionary
@property
def var_module_dict(self):
return self._var_module_dict

# Hidden setter method for Variable Module Dictionary
@var_module_dict.setter
def _var_module_dict(self, value):
# Custom logic can be added in the setter
self._var_module_dict = value

# Getter method for Variable Weight Dictionary
@property
def var_weight_dict(self):
return self._var_weight_dict

# Hidden setter method for Variable Weight Dictionary
@var_weight_dict.setter
def _var_weight_dict(self, value):
# Custom logic can be added in the setter
self._var_weight_dict = value

# Getter method for Variable Bias Dictionary
@property
def var_bias_dict(self):
return self._var_bias_dict

# Hidden setter method for Variable Bias Dictionary
@var_bias_dict.setter
def _var_bias_dict(self, value):
# Custom logic can be added in the setter
self._var_bias_dict = value

def start_analyze_module(self, module:nn.Module):
'''
Expand Down Expand Up @@ -276,10 +243,10 @@ def analyze_module_by_cases(self, var_name, module):
pass
self.update_module_flag(module)
self.update_var_flag(var_name)
list_module_names = []
var_module_layer = {}
for name, layer in module.named_children():
var_whole_name, module_name = self.analyze_module(name, layer)
list_module_names.append(var_whole_name)
var_module_layer[var_whole_name] = module_name
self.remove_module_flag()
self.remove_var_flag()
# Either if current module is a pyTorch in-built module
Expand All @@ -288,13 +255,20 @@ def analyze_module_by_cases(self, var_name, module):
if self.is_torch_module(module):
self.analyze_inbuild_module()
else:
self.analyze_defined_module()
self.analyze_defined_module(module, var_module_layer)
return 0

def analyze_inbuild_module(self):
return 0

def analyze_defined_module(self):
def analyze_defined_module(self, module, var_module_layer):
module_code = inspect.getsource(type(module))
module_ast = ast.parse(module_code)
# [var_module_layer] is the variable-module dictionary of current layer
analyzer = ModuleAstAnalyzer(var_module_layer)
analyzer.visit(module_ast)
result = analyzer.module_map
print("Results:", result)
return 0

# def analyze_inbuild_module(self, var_name, var_whole_name, module_name, module):
Expand Down Expand Up @@ -410,9 +384,11 @@ def print_current_layer_information(self, var_whole_name, module_name, depth=2):
#============================================================

class ModuleAstAnalyzer(ast.NodeVisitor):
def __init__(self, module_list):
def __init__(self, var_module_dict):
# Parent stack
self.parent_stack = []
self.module_list:dict = module_list
# In [ModuleAstAnalyzer], [var_module_dict] is just the current analyzed layer
self.var_module_dict:dict = var_module_dict
self.module_map = []

self.temp_var_ids = []
Expand All @@ -425,55 +401,38 @@ def __init__(self, module_list):
self.out_flag = None
self.out_dict = {}

#----------------------------------------
#---------Generic visit enhance----------
#----------------------------------------

def generic_visit_with_parent_stack(self, node):
self.parent_stack.append(node)
self.generic_visit(node)
self.parent_stack.pop()

# Deal with functions and classes

def visit_FunctionDef(self, node):
# Only take look with how modules are called in nn.Module.forward
if node.name == 'forward':
for arg in node.args.args:
# print("arguments in forward", arg.arg)
if not arg.arg == 'self':
self.forward_var_list.append(arg.arg)
self.generic_visit(node)

def visit_ClassDef(self, node):
self.generic_visit(node)

# Inside self.forward
# analyze the nerual network structure

def visit_Name(self, node: Name) -> Any:
# if node.id in self.module_list:
# parent_type = [str(type(p)) for p in self.parent_stack]
if node in self.parent_stack:
return 0
if node.id in self.forward_var_list:
# parent_type = [str(type(p)) for p in self.parent_stack]
self.analyze_net_name(self.parent_stack, node)
self.generic_visit_with_parent_stack(node)

def visit_Call(self, node: Call) -> Any:
self.generic_visit_with_parent_stack(node)
#----------------------------------------
#----Boolean determination functions-----
#----------------------------------------

def visit_For(self, node: For) -> Any:
self.generic_visit_with_parent_stack(node)
def all_NameAttribute(self, targets):
return self.all_Name(targets) or self.all_Attribute(targets)

def visit_BinOp(self, node: BinOp) -> Any:
self.generic_visit_with_parent_stack(node)
def all_Name(self, targets):
for target in targets:
if not isinstance(target, ast.Name):
return False
return True

def visit_Attribute(self, node: Attribute) -> Any:
self.generic_visit_with_parent_stack(node)

def visit_Assign(self, node: Assign) -> Any:
self.generic_visit_with_parent_stack(node)

# Typer determination functions
def all_Attribute(self, targets):
for target in targets:
if not isinstance(target, ast.Attribute):
return False
return True

#----------------------------------------
#---------Generically find names---------
#----------------------------------------

def find_full_name(self, node):
if isinstance(node, ast.Name):
return node.id
Expand All @@ -488,9 +447,7 @@ def find_full_name(self, node):
return node.attr
else:
pass # This won't happen
# elif isinstance(node, ast.Call):
# return -'.'


def find_all_names(self, node_list):
ret = []
for node in node_list:
Expand All @@ -501,146 +458,52 @@ def find_all_names(self, node_list):
elif isinstance(node, ast.BinOp):
ret.append(astor.to_source(node))
return ret

#--------------------------------------------------
#--Analyze special functions (such as forward())---
#--------------------------------------------------

def all_NameAttribute(self, targets):
return self.all_Name(targets) or self.all_Attribute(targets)
def visit_FunctionDef(self, node):
# Only take look with how modules are called in nn.Module.forward
if node.name == 'forward':
for arg in node.args.args:
# print("arguments in forward", arg.arg)
if not arg.arg == 'self':
self.forward_var_list.append(arg.arg)
self.generic_visit(node)

#--------------------------------------------------
#----Rewrite some visting for history tracking-----
#--------------------------------------------------

def visit_Call(self, node: Call) -> Any:
self.generic_visit_with_parent_stack(node)

def all_Name(self, targets):
for target in targets:
if not isinstance(target, ast.Name):
return False
return True

def all_Attribute(self, targets):
for target in targets:
if not isinstance(target, ast.Attribute):
return False
return True
def visit_For(self, node: For) -> Any:
self.generic_visit_with_parent_stack(node)

def visit_BinOp(self, node: BinOp) -> Any:
self.generic_visit_with_parent_stack(node)

def visit_Attribute(self, node: Attribute) -> Any:
self.generic_visit_with_parent_stack(node)

def visit_Assign(self, node: Assign) -> Any:
self.generic_visit_with_parent_stack(node)

def visit_Tuple(self, node: Tuple) -> Any:
return self.generic_visit_with_parent_stack(node)

def analyze_net_name(self, parents, this:Name):
if len(parents) == 0:
return 0
# print(f"{Color.PURPLE}{parents[0]}{self.analyzed_source_codes}{Color.END}")
# if astor.to_source(parents[-1]) in self.analyzed_source_codes:
# return 0
if parents[0] in self.analyzed_source_codes:
#--------------------------------------------------
#-------------Core code of this class--------------
#--------------------------------------------------

def visit_Name(self, node: Name) -> Any:
# if node.id in self.var_module_dict:
# parent_type = [str(type(p)) for p in self.parent_stack]
if node in self.parent_stack:
return 0
# At here, we made the parents upside down
# so the above [0] is the future [-1]
parents = parents[::-1]
current_var = this.id
# out_flag = None
this_flag = this
# current_modules = self.module_list
if len(parents) == 1:
p = parents[0]
if isinstance(p, ast.Assign) and p.targets[0] == this:
return 0
parent_type = [str(type(p)) for p in parents]
# print(parent_type, this.id, astor.to_source(parents[0]))
elif len(parents) == 2:
p_0 = parents[0]
p_1 = parents[1]
if isinstance(p_0, ast.Call) and isinstance(p_1, ast.Assign):
if self.find_full_name(p_0.func):
if self.find_full_name(p_0.func) in list(self.module_list.keys()):
op = self.module_list[self.find_full_name(p_0.func)]
op_in = [self.find_full_name(a) for a in p_0.args]
for oi in op_in:
if oi in list(self.out_dict.keys()):
op_in.remove(oi)
op_in.append(f'{oi}.{self.out_dict[oi]}')
op_id = generate_id(id_list)
id_list.append(op_id)
op_out = [f'{self.find_full_name(a)}.{op_id}' for a in p_1.targets]
for oo in op_out:
if oo.split('.')[0] in list(self.out_dict.keys()):
self.out_dict[oo.split('.')[0]] = oo.split('.')[1]
tri_node = (op_in, op_out, op)
self.module_map.append(tri_node)
pass
elif isinstance(p_0, ast.Attribute) and isinstance(p_1, ast.Assign) and isinstance(p_1.targets[0], ast.Tuple) and p_0.value == this:
op = p_0.attr
op_out = []
target = p_1.targets[0]
self.out_flag = this.id
if isinstance(target, ast.Tuple):
for elt in target.elts:
if isinstance(elt, ast.Name):
op_id = generate_id(id_list)
id_list.append(op_id)
op_out.append(f'{elt.id}.{op_id}')
# if not op == 'shape':
self.forward_param_list.append(elt.id)
op_id = generate_id(id_list)
id_list.append(op_id)
op_in = [f'{this.id}.{op_id}']
self.out_flag = f'{this.id}.{op_id}'
self.out_dict[this.id] = op_id
tri_node = (op_in, op_out, op)
self.module_map.append(tri_node)
# for target in p_1.targets:
# print(target)
# if isinstance(target, ast.Name):
# op_id = generate_id(id_list)
# op_in = self.find_full_name(p_0.value)
# op_out = f'{target.id}.{generate_id(op_id)}'
# id_list.append(op_id)
# if not target.id in self.forward_var_list: self.forward_var_list.append(target.id)
# tri_node = (op_in, op_out, op)
# self.module_map.append(tri_node)
# print(self.module_map)
else:
pass
parent_type = [str(type(p)) for p in parents]
# print(parent_type, this.id, astor.to_source(parents[-1]))
else:
for i, p in enumerate(parents):
if isinstance(p, ast.Call):
# print(self.find_full_name(p.func))
if self.find_full_name(p.func):
if self.find_full_name(p.func) in list(self.module_list.keys()):
op = self.module_list[self.find_full_name(p.func)]
if self.out_flag:
op_in = [f'{self.out_flag}']
else:
op_in = None
if i < len(parents)-1:
op_id = generate_id(id_list)
id_list.append(op_id)
op_out = op_id
self.out_flag = op_out
else:
self.out_flag = None
tri_node = (op_in, [op_out], op)
self.module_map.append(tri_node)
pass
elif isinstance(p, ast.Attribute):
if p.value == this_flag:
op = p.attr
if self.out_flag:
op_in = [f'{self.out_flag}']
else:
op_in = None
if i < len(parents)-1:
op_id = generate_id(id_list)
id_list.append(op_id)
op_out = op_id
self.out_flag = op_out
else:
self.out_flag = None
for oo in [op_out]:
if oo.split('.')[0] in list(self.out_dict.keys()):
self.out_dict[oo.split('.')[0]] = oo.split('.')[1]
tri_node = (op_in, [op_out], op)
self.module_map.append(tri_node)
this_flag = p
parent_type = [str(type(p)) for p in parents]
# print(parent_type, this.id, astor.to_source(parents[-1]))
self.analyzed_source_codes.add(parents[-1])
# print(f"{Color.GREEN}{self.forward_var_list}{Color.END}")
# print(f"{Color.GREEN}{self.forward_param_list}{Color.END}")
return 0
if node.id in self.forward_var_list:
# parent_type = [str(type(p)) for p in self.parent_stack]
self.analyze_net_name(self.parent_stack, node)
self.generic_visit_with_parent_stack(node)

0 comments on commit 09e7a7d

Please sign in to comment.