chore(mypy): add type annotations
lukasrothenberger committed Jul 12, 2024
1 parent b9d588f commit b3f79bf
Showing 82 changed files with 2,185 additions and 1,036 deletions.
28 changes: 14 additions & 14 deletions discopop_explorer/PEGraphX.py
@@ -159,7 +159,7 @@ def __init__(self, type: EdgeType):
self.metadata_sink_ancestors = []
self.metadata_source_ancestors = []

-def __str__(self):
+def __str__(self) -> str:
return self.var_name if self.var_name is not None else str(self.etype)


@@ -233,13 +233,13 @@ def contains_line(self, other_line: str) -> bool:
return True
return False

-def __str__(self):
+def __str__(self) -> str:
return self.id

-def __eq__(self, other):
+def __eq__(self, other: Any) -> bool:
return isinstance(other, Node) and other.file_id == self.file_id and other.node_id == self.node_id

-def __hash__(self):
+def __hash__(self) -> int:
return hash(self.id)

def get_parent_id(self, pet: PEGraphX) -> Optional[NodeID]:
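
Note on the pattern above: mypy expects `__eq__` to accept `Any` so the override stays compatible with `object.__eq__`, and `__eq__`/`__hash__` must agree on the fields they use so set and dict lookups behave consistently. A minimal self-contained sketch of the same convention (illustrative names, not code from this repository):

```python
from typing import Any

class NodeKey:
    def __init__(self, file_id: int, node_id: int) -> None:
        self.file_id = file_id
        self.node_id = node_id

    def __eq__(self, other: Any) -> bool:
        # Any (not NodeKey) keeps the signature compatible with object.__eq__
        return isinstance(other, NodeKey) and (other.file_id, other.node_id) == (self.file_id, self.node_id)

    def __hash__(self) -> int:
        # hash the same fields __eq__ compares so lookups stay consistent
        return hash((self.file_id, self.node_id))

assert NodeKey(1, 2) == NodeKey(1, 2)
assert len({NodeKey(1, 2), NodeKey(1, 2)}) == 1
```
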
@@ -722,7 +722,7 @@ def from_parsed_input(
print("\tAdded dependencies...")
return cls(g, reduction_vars, pos)

-def map_static_and_dynamic_dependencies(self):
+def map_static_and_dynamic_dependencies(self) -> None:
print("\tMapping static to dynamic dependencies...")
print("\t\tIdentifying mappings between static and dynamic memory regions...", end=" ")
mem_reg_mappings: Dict[MemoryRegion, Set[MemoryRegion]] = dict()
@@ -883,7 +883,7 @@ def calculateFunctionMetadata(
self.g.remove_edge(edge[0], edge[1], edge[2])
print("Cleaning dependencies II done.")

-def calculateLoopMetadata(self):
+def calculateLoopMetadata(self) -> None:
print("Calculating loop metadata")

# calculate loop indices
@@ -905,7 +905,7 @@ def calculateLoopMetadata(self):

print("Calculating loop metadata done.")

-def show(self):
+def show(self) -> None:
"""Plots the graph
:return:
@@ -1011,7 +1011,7 @@ def all_nodes(self) -> List[Node]: ...
@overload
def all_nodes(self, type: Union[Type[NodeT], Tuple[Type[NodeT], ...]]) -> List[NodeT]: ...

-def all_nodes(self, type=Node):
+def all_nodes(self, type: Any = Node) -> List[NodeT]:
"""List of all nodes of specified type
:param type: type(s) of nodes
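
Note: `all_nodes`, `subtree_of_type`, and `subtree_of_type_rec` all follow the same typing idiom, typed `@overload` stubs for callers plus a single permissive runtime implementation. A minimal sketch with illustrative names (not code from this repository):

```python
from typing import Any, List, Tuple, Type, TypeVar, Union, overload

class Node: ...
class LoopNode(Node): ...

NodeT = TypeVar("NodeT", bound=Node)

_NODES: List[Node] = [Node(), LoopNode()]

@overload
def all_nodes() -> List[Node]: ...
@overload
def all_nodes(type: Union[Type[NodeT], Tuple[Type[NodeT], ...]]) -> List[NodeT]: ...

def all_nodes(type: Any = Node) -> List[Any]:
    # one runtime body; the overloads above give callers precise return types
    return [n for n in _NODES if isinstance(n, type)]

loop_nodes: List[LoopNode] = all_nodes(LoopNode)  # mypy infers List[LoopNode]
```
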
@@ -1057,7 +1057,7 @@ def subtree_of_type(self, root: Node) -> List[Node]: ...
@overload
def subtree_of_type(self, root: Node, type: Union[Type[NodeT], Tuple[Type[NodeT], ...]]) -> List[NodeT]: ...

-def subtree_of_type(self, root, type=Node):
+def subtree_of_type(self, root: Node, type: Any = Node) -> List[NodeT]:
"""Gets all nodes in subtree of specified type including root
:param root: root node
@@ -1074,7 +1074,7 @@ def subtree_of_type_rec(
self, root: Node, visited: Set[Node], type: Union[Type[NodeT], Tuple[Type[NodeT], ...]]
) -> List[NodeT]: ...

-def subtree_of_type_rec(self, root, visited, type=Node):
+def subtree_of_type_rec(self, root: Node, visited: Set[Node], type: Any = Node) -> List[NodeT]:
"""recursive helper function for subtree_of_type"""
# check if root is of type target
res = []
@@ -1286,7 +1286,7 @@ def is_scalar_val(self, allVars: List[Variable], var: str) -> bool:
"""
for x in allVars:
if x.name == var:
-return not (x.type.endswith("**") or x.type.startswith("ARRAY" or x.type.startswith("[")))
+return not (x.type.endswith("**") or x.type.startswith("ARRAY") or x.type.startswith("["))
else:
return False
raise ValueError("allVars must not be empty.")
@@ -1557,14 +1557,14 @@ def get_reduction_sign(self, line: str, name: str) -> str:
return rv["operation"]
return ""

-def dump_to_pickled_json(self):
+def dump_to_pickled_json(self) -> str:
"""Encodes and returns the entire Object into a pickled json string.
The encoded string can be reconstructed into an object by using:
jsonpickle.decode(json_str)
:return: encoded string
"""
-return jsonpickle.encode(self)
+return cast(str, jsonpickle.encode(self))

def check_reachability(self, target: Node, source: Node, edge_types: List[EdgeType]) -> bool:
"""check if target is reachable from source via edges of types edge_type.
@@ -1667,7 +1667,7 @@ def check_reachability_and_get_path_nodes(
queue.append((cast(CUNode, self.node_at(e[0])), tmp_path))
return False, []

-def dump_to_gephi_file(self, name="pet.gexf"):
+def dump_to_gephi_file(self, name: str = "pet.gexf") -> None:
"""Note: Destroys the PETGraph!"""
# replace node data with label
for node_id in self.g.nodes:
2 changes: 1 addition & 1 deletion discopop_explorer/__main__.py
@@ -159,7 +159,7 @@ def parse_args() -> ExplorerArguments:
)


-def main():
+def main() -> None:
arguments = parse_args()
setup_logger(arguments)
run(arguments)
6 changes: 3 additions & 3 deletions discopop_explorer/discopop_explorer.py
@@ -65,10 +65,10 @@ class ExplorerArguments(GeneralArguments):
microbench_file: Optional[str]
load_existing_doall_and_reduction_patterns: bool

-def __post_init__(self):
+def __post_init__(self) -> None:
self.__validate()

-def __validate(self):
+def __validate(self) -> None:
"""Validate the arguments passed to the discopop_explorer, e.g check if given files exist"""
validation_failure = False

@@ -114,7 +114,7 @@ def __run(
hotspot_functions: Optional[Dict[HotspotType, List[Tuple[int, int, HotspotNodeType, str]]]] = None,
load_existing_doall_and_reduction_patterns: bool = False,
) -> DetectionResult:
-pet = PEGraphX.from_parsed_input(*parse_inputs(cu_xml, dep_file, reduction_file, file_mapping))
+pet = PEGraphX.from_parsed_input(*parse_inputs(cu_xml, dep_file, reduction_file, file_mapping)) # type: ignore
print("PET CREATION FINISHED.")
# pet.show()
# TODO add visualization
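
Note: a bare `# type: ignore` as used above silences every mypy error on the line. Where the offending error is known, mypy also accepts an error-code-scoped ignore (standard mypy syntax, though not what this commit uses), which keeps unrelated errors visible. A generic sketch with assumed names:

```python
from typing import Tuple

def run(cu_xml: str, dep_file: str) -> None:
    print(cu_xml, dep_file)

args: Tuple[object, object] = ("cu.xml", "deps.txt")
run(*args)  # type: ignore[arg-type]  # narrower than a bare "# type: ignore"
```
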
2 changes: 1 addition & 1 deletion discopop_explorer/generate_Data_CUInst.py
@@ -179,7 +179,7 @@ def cu_instantiation_input_cpp(pet: PEGraphX, output_file: str) -> None:
def wrapper(cu_xml: str, dep_file: str, loop_counter_file: str, reduction_file: str, output_file: str) -> None:
"""Wrapper to generate the Data_CUInst.txt file, required for the generation of CUInstResult.txt"""
# 1. generate PET Graph
-pet = PEGraphX.from_parsed_input(*parse_inputs(cu_xml, dep_file, loop_counter_file, reduction_file))
+pet = PEGraphX.from_parsed_input(*parse_inputs(cu_xml, dep_file, loop_counter_file, reduction_file)) # type: ignore
# 2. Generate Data_CUInst.txt
cu_instantiation_input_cpp(pet, output_file)

4 changes: 2 additions & 2 deletions discopop_explorer/json_serializer.py
@@ -7,7 +7,7 @@
# directory for details.

from json import JSONEncoder
-from typing import Dict, Any
+from typing import Dict, Any, List
from discopop_explorer.pattern_detectors.PatternBase import PatternBase
from discopop_library.discopop_optimizer.classes.context.Update import Update
from discopop_library.discopop_optimizer.classes.types.DataAccessType import WriteDataAccess
@@ -37,7 +37,7 @@ def filter_members(d: Dict[Any, Any]) -> Dict[Any, Any]:
class PatternBaseSerializer(JSONEncoder):
"""Json Encoder for Pattern Info"""

-def default(self, o):
+def default(self, o: Any) -> Any:
try:
iterable = iter(o)
except TypeError:
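
Note: `default` is the standard `JSONEncoder` extension hook; it is invoked only for objects the encoder cannot serialize natively, and the `iter(o)` probe shown above turns arbitrary iterables into JSON arrays. A self-contained sketch of the same pattern:

```python
import json
from json import JSONEncoder
from typing import Any

class IterableEncoder(JSONEncoder):
    def default(self, o: Any) -> Any:
        try:
            return list(iter(o))       # any iterable becomes a JSON array
        except TypeError:
            return super().default(o)  # unknown types still raise TypeError

print(json.dumps({"vals": {1, 2}}, cls=IterableEncoder))  # {"vals": [1, 2]}
```
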
8 changes: 4 additions & 4 deletions discopop_explorer/parallel_utils.py
@@ -6,18 +6,18 @@
# the 3-Clause BSD License. See the LICENSE file in the package base
# directory for details.

-from .PEGraphX import Node, NodeID, PEGraphX
-from typing import List, Optional, Set
+from .PEGraphX import FunctionNode, Node, NodeID, PEGraphX
+from typing import Any, List, Optional, Set, Tuple

global_pet: Optional[PEGraphX] = None


-def pet_function_metadata_initialize_worker(pet):
+def pet_function_metadata_initialize_worker(pet: PEGraphX) -> None:
global global_pet
global_pet = pet


-def pet_function_metadata_parse_func(func_node):
+def pet_function_metadata_parse_func(func_node: FunctionNode) -> Tuple[NodeID, Any, set[NodeID]]:
if global_pet is None:
raise ValueError("global_pet is None!")

27 changes: 16 additions & 11 deletions discopop_explorer/parser.py
@@ -13,9 +13,10 @@
from collections import defaultdict
from dataclasses import dataclass
from os.path import abspath, dirname
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple

from lxml import objectify # type:ignore
+from lxml.objectify import ObjectifiedElement # type: ignore

# Map to record which line belongs to read set of nodes. LID -> NodeIds
readlineToCUIdMap = defaultdict(set) # type: ignore
@@ -46,7 +47,7 @@ class LoopData(object):
maximum_iteration_count: int


-def __parse_xml_input(xml_fd):
+def __parse_xml_input(xml_fd: TextIOWrapper) -> Dict[str, ObjectifiedElement]:
xml_content = ""
for line in xml_fd.readlines():
if not (line.rstrip().endswith("</Nodes>") or line.rstrip().endswith("<Nodes>")):
@@ -68,12 +69,13 @@ def __map_dummy_nodes(cu_dict):
# entry exists already! merge the two entries
pass
else:
+tmp = node.get("id")
cu_dict[node.get("id")] = node

return cu_dict


-def __map_dummy_nodes(cu_dict):
+def __map_dummy_nodes(cu_dict: Dict[str, ObjectifiedElement]) -> Dict[str, ObjectifiedElement]:
dummy_node_args_to_id_map = defaultdict(list)
func_node_args_to_id_map = dict()
dummy_to_func_ids_map = dict()
@@ -208,12 +210,14 @@ def __parse_dep_file(dep_fd: TextIOWrapper, output_path: str) -> Tuple[List[Depe
return dependencies_list, loop_data_list


-def parse_inputs(cu_file, dependencies, reduction_file, file_mapping):
+def parse_inputs(
+    cu_file: str, dependencies_file_path: str, reduction_file: str, file_mapping: str
+) -> Tuple[Dict[str, ObjectifiedElement], List[DependenceItem], Dict[str, LoopData], Optional[List[Dict[str, str]]]]:
with open(cu_file) as f:
cu_dict = __parse_xml_input(f)
cu_dict = __map_dummy_nodes(cu_dict)

-with open(dependencies) as f:
+with open(dependencies_file_path) as f:
dependencies, loop_info = __parse_dep_file(f, dirname(abspath(cu_file)))

loop_data = {loop.line_id: loop for loop in loop_info}
@@ -245,7 +249,7 @@ def parse_inputs(cu_file, dependencies, reduction_file, file_mapping):
return cu_dict, dependencies, loop_data, reduction_vars


-def is_reduction(reduction_line, fmap_lines, file_mapping):
+def is_reduction(reduction_line: str, fmap_lines: List[str], file_mapping: str) -> bool:
rex = re.compile("FileID : ([0-9]*) Loop Line Number : [0-9]* Reduction Line Number : ([0-9]*) ")
if not rex:
return False
@@ -265,7 +269,7 @@ def is_reduction(reduction_line, fmap_lines, file_mapping):
return possible_reduction(file_line, src_lines)


-def possible_reduction(line, src_lines):
+def possible_reduction(line: int, src_lines: List[str]) -> bool:
assert line > 0 and line <= len(src_lines), "invalid src line"
src_line = src_lines[line - 1]
while not ";" in src_line:
@@ -299,7 +303,7 @@ def possible_reduction(line, src_lines):
return True


-def get_filepath(file_id, fmap_lines, file_mapping):
+def get_filepath(file_id: int, fmap_lines: List[str], file_mapping: str) -> str:
assert file_id > 0 and file_id <= len(fmap_lines), "invalid file id"
line = fmap_lines[file_id - 1]
tokens = line.split(sep="\t")
@@ -309,7 +313,7 @@ def get_filepath(file_id, fmap_lines, file_mapping):
return tokens[1]


-def get_enclosed_str(data):
+def get_enclosed_str(data: str) -> str:
num_open_brackets = 1
for i in range(0, len(data)):
if data[i] == "[":
@@ -318,10 +322,11 @@ def get_enclosed_str(data):
num_open_brackets = num_open_brackets - 1
if num_open_brackets == 0:
return data[0:i]
raise ValueError("No enclosed str found!")


-def find_array_indices(array_name, src_line):
-indices = []
+def find_array_indices(array_name: str, src_line: str) -> List[str]:
+indices: List[str] = []
uses = list(re.finditer(array_name, src_line))
for use in uses:
if src_line[use.end()] == "[":
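
Note: `get_enclosed_str` receives the text that follows an opening `[` and returns everything up to the bracket that balances it; the newly added `raise` makes the no-match case explicit instead of silently returning `None`. A self-contained usage sketch, reimplemented from the hunk above:

```python
def get_enclosed_str(data: str) -> str:
    num_open_brackets = 1
    for i in range(0, len(data)):
        if data[i] == "[":
            num_open_brackets += 1
        elif data[i] == "]":
            num_open_brackets -= 1
            if num_open_brackets == 0:
                return data[0:i]
    raise ValueError("No enclosed str found!")

print(get_enclosed_str("i + a[j]] = 0"))  # -> "i + a[j]"
```
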
37 changes: 19 additions & 18 deletions discopop_explorer/pattern_detection.py
@@ -12,6 +12,7 @@

from alive_progress import alive_bar # type: ignore

+from discopop_explorer.pattern_detectors.combined_gpu_patterns.classes.Aliases import VarName
from discopop_explorer.pattern_detectors.task_parallelism.task_parallelism_detector import (
build_preprocessed_graph_and_run_detection as detect_tp,
)
@@ -194,20 +195,20 @@

def load_existing_doall_and_reduction_patterns(
self,
-project_path,
-cu_dict,
-dependencies,
-loop_data,
-reduction_vars,
-file_mapping,
-cu_inst_result_file,
-llvm_cxxfilt_path,
-discopop_build_path,
-enable_patterns,
-enable_task_pattern,
-enable_detection_of_scheduling_clauses,
-hotspots,
-):
+project_path: str,
+cu_dict: str,
+dependencies: str,
+loop_data: str,
+reduction_vars: str,
+file_mapping: Optional[str],
+cu_inst_result_file: Optional[str],
+llvm_cxxfilt_path: Optional[str],
+discopop_build_path: Optional[str],
+enable_patterns: str,
+enable_task_pattern: bool,
+enable_detection_of_scheduling_clauses: bool,
+hotspots: Optional[Dict[HotspotType, List[Tuple[int, int, HotspotNodeType, str]]]],
+) -> DetectionResult:
"""skips the pattern discovery on the CU graph and loads a pre-existing pattern file"""
self.__merge(False, True)
self.pet.map_static_and_dynamic_dependencies()
@@ -235,12 +236,12 @@
print("PATTERNS:")
print(pattern_contents)

-def __get_var_obj_from_name(name):
-return Variable(type="", name=name, defLine="", accessMode="", sizeInByte="0")
+def __get_var_obj_from_name(name: VarName) -> Variable:
+return Variable(type="", name=name, defLine="", accessMode="", sizeInByte=0)

-def __get_red_var_obj_from_name(name):
+def __get_red_var_obj_from_name(name: str) -> Variable:
split_name = name.split(":")
v = Variable(type="", name=split_name[0], defLine="", accessMode="", sizeInByte="0")
v = Variable(type="", name=VarName(split_name[0]), defLine="", accessMode="", sizeInByte=0)
v.operation = split_name[1]
return v
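
Note: the `VarName(...)` wrapping and the `sizeInByte=0` change suggest that `VarName` is a `typing.NewType` over `str` and that `sizeInByte` is typed as `int`; both are assumptions here, since the definitions live outside this diff. A sketch of why a `NewType` forces the explicit wrap:

```python
from typing import NewType

VarName = NewType("VarName", str)  # assumed definition, not shown in this diff

def describe(name: VarName) -> str:
    return "variable " + name  # at runtime a VarName is just a str

print(describe(VarName("sum")))  # OK for mypy and at runtime
# describe("sum")  # mypy error: "str" is not "VarName"
```
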

2 changes: 1 addition & 1 deletion discopop_explorer/pattern_detectors/PatternBase.py
@@ -44,7 +44,7 @@ def __init__(self, node: Node):
self.end_line = node.end_position()
self.applicable_pattern = True

-def to_json(self):
+def to_json(self) -> str:
dic = self.__dict__
keys = [k for k in dic.keys()]
for key in keys: