
Commit: Update hash function

yilin-bao authored Dec 8, 2023
1 parent 0db7ba1 commit 632cd65

Showing 1 changed file with 72 additions and 17 deletions: analyzer.py
@@ -1,9 +1,11 @@
import ast, astor
import hashlib
import inspect
from ast import Assign, Attribute, BinOp, Call, For, Name, Return, Tuple, mod
from typing import Any
@@ -67,7 +69,12 @@ class Color:
    # ANSI escape code to reset text attributes to default
    END = '\033[0m'

id_list = []

def hash_code(code):
    # Return the SHA-256 hex digest of a source-code string
    sha256_hash = hashlib.sha256()
    sha256_hash.update(code.encode('utf-8'))
    hashed_code = sha256_hash.hexdigest()
    return hashed_code
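# Illustrative usage (editor's example, not part of the commit): the digest is
# deterministic, so identical source text always produces the same id -- this
# is what lets a content hash stand in for a per-run counter like [id_list]:
#
#     h1 = hash_code("x = self.cls_layer(x)")
#     h2 = hash_code("x = self.cls_layer(x)")
#     assert h1 == h2 and len(h1) == 64  # 64-char hex digest, stable across runs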

#============================================================
#======================[Hash] functions======================
@@ -384,21 +391,19 @@ def print_current_layer_information(self, var_whole_name, module_name, depth=2):

class ModuleAstAnalyzer(ast.NodeVisitor):
    def __init__(self, var_module_dict):
        # Parent stack: records every parent node visited on the way down
        self.parent_stack = []
        # In [ModuleAstAnalyzer], [var_module_dict] is just the currently analyzed layer
        self.var_module_dict: dict = var_module_dict
        self.module_map = []

        # forward_tensor_list: tensors, matrices, vectors used in deep learning
        # forward_param_list: ints, floats, dimension variables, other variables
        self.forward_tensor_list = []
        self.forward_param_list = []
        self.analyzed_source_codes = set()
        self.out_flag = None
        # Store the current version (by uuid) of each tensor/variable
        self.out_dict = {}
        self.hash_var_dict = {}

    #----------------------------------------
    #---------Generic visit enhance----------
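    # NOTE (editor's sketch): the body of this section is collapsed in the
    # diff view. Later methods rely on generic_visit_with_parent_stack; a
    # plausible implementation -- an assumption, not taken from this commit --
    # pushes the node before descending into children and pops it afterwards:
    #
    #     def generic_visit_with_parent_stack(self, node):
    #         self.parent_stack.append(node)
    #         self.generic_visit(node)
    #         self.parent_stack.pop()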
@@ -468,7 +473,8 @@ def visit_FunctionDef(self, node):
        for arg in node.args.args:
            # print("arguments in forward", arg.arg)
            if not arg.arg == 'self':
                self.forward_tensor_list.append(arg.arg)
                # hashlib.sha256() requires bytes, not str; go through
                # hash_code(), which handles the utf-8 encoding
                # (the digest is not yet stored anywhere)
                hash_code(arg.arg)
        self.generic_visit(node)
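    # Editor's example (not from the commit): given a forward method defined as
    # `def forward(self, x, mask):`, the loop above leaves
    # self.forward_tensor_list == ['x', 'mask'], skipping `self`.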

    #--------------------------------------------------
@@ -498,12 +504,61 @@ def visit_Tuple(self, node: Tuple) -> Any:
    #--------------------------------------------------

    def visit_Name(self, node: Name) -> Any:
        # if node.id in self.var_module_dict:
        # parent_type = [str(type(p)) for p in self.parent_stack]
        if node.id in self.forward_tensor_list or node.id in self.forward_param_list:
            self.analyze_net_name(self.parent_stack, node)
        self.generic_visit_with_parent_stack(node)

    def analyze_net_name(self, parents, this: Name):
        if len(parents) == 0:
            return 0
        # Reverse so that parents[0] is the immediate parent of [this]
        parents = parents[::-1]
        if len(parents) == 1:
            self.special_case_length_one(parents[0], this)
        elif len(parents) == 2:
            self.special_case_length_two(parents[0], parents[1], this)
        else:
            self.print_parents_and_code()
        return 0
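    # Editor's example (assumption about stack order): while visiting the
    # right-hand `x` in `x = self.embedding_layer(x)`, parent_stack holds
    # [Assign, Call] outermost-first, so after parents[::-1] the call above
    # dispatches to special_case_length_two(parent=Call, grandparent=Assign).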

    def special_case_length_one(self, parent, this: Name):
        if isinstance(parent, ast.Attribute):
            # e.g. nn.Module or self.revised -- a bare attribute access, nothing to track
            pass
        elif isinstance(parent, ast.Assign) and parent.targets[0] == this:
            # e.g. x = self.embedding_layer(x)
            #      x = self.transformer(x)
            #      x = self.post_transformer_ln(x)
            #      x = self.cls_layer(x)
            # What we found is the variable [x] on the left side of the assign, so ignore it
            pass
        elif isinstance(parent, ast.Assign) and not parent.targets[0] == this:
            self.print_parents_and_code()
        else:
            self.print_parents_and_code()

    def special_case_length_two(self, parent, grandparent, this: Name):
        if isinstance(grandparent, ast.Assign) and isinstance(parent, ast.Call):
            print(parent.func)
            self.print_parents_and_code()
        elif isinstance(grandparent, ast.Call) and isinstance(parent, ast.Call):
            self.print_parents_and_code()
        elif isinstance(grandparent, ast.Assign) and isinstance(parent, ast.Attribute):
            print(parent.attr)
            self.print_parents_and_code()
        elif isinstance(grandparent, ast.For) and isinstance(parent, ast.Assign):
            self.print_parents_and_code()
        else:
            self.print_parents_and_code()
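    # Editor's note: with the stack from the example above, a line like
    # `x = self.embedding_layer(x)` lands in the first branch (grandparent
    # Assign, parent Call), whereas `y = x.shape` would land in the
    # Assign/Attribute branch.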

    def print_parents_and_code(self):
        # Print the parent-node chain, then the source of the outermost parent
        if len(self.parent_stack) >= 1:
            print(f'{Color.BOLD_BLUE}{self.parent_stack}{Color.END} {Color.LIME}{astor.to_source(self.parent_stack[0])}{Color.END}', end=' ')
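A minimal, self-contained sketch (editor's addition, independent of this repository) of the parent-stack pattern the visitor above builds on; all names here are illustrative only:

import ast

class ParentStackVisitor(ast.NodeVisitor):
    """Tracks the chain of ancestor nodes while walking an AST."""

    def __init__(self):
        self.parent_stack = []

    def generic_visit_with_parent_stack(self, node):
        # Push before descending into children, pop on the way back up
        self.parent_stack.append(node)
        self.generic_visit(node)
        self.parent_stack.pop()

    def visit_Assign(self, node):
        self.generic_visit_with_parent_stack(node)

    def visit_Call(self, node):
        self.generic_visit_with_parent_stack(node)

    def visit_Attribute(self, node):
        self.generic_visit_with_parent_stack(node)

    def visit_Name(self, node):
        # Report each Name together with its ancestor chain, innermost last
        print(node.id, [type(p).__name__ for p in self.parent_stack])

tree = ast.parse("x = self.embedding_layer(x)")
ParentStackVisitor().visit(tree)
# Output:
#   x ['Assign']                          (left-hand target)
#   self ['Assign', 'Call', 'Attribute']
#   x ['Assign', 'Call']                  (argument on the right)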
