forked from KULeuven-MICAS/zigzag
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
52 lines (47 loc) · 1.9 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from zigzag.classes.stages import *
import argparse
import re
# Parse the command-line inputs: workload (--model), mapping (--mapping) and
# accelerator (--accelerator). All three are required. The mapping and the
# accelerator are given as dotted module paths. The workload is either a dotted
# module path or a path to an ONNX file (a trailing 'onnx' component is handled
# when the workload name is derived from it).
parser = argparse.ArgumentParser(description="Setup zigzag-v2 inputs")
parser.add_argument('--model', metavar='path', required=True, help='module path to workload or path to an ONNX file, e.g. inputs.examples.workloads.resnet18')
parser.add_argument('--mapping', metavar='path', required=True, help='module path to the mapping, e.g. inputs.examples.mapping.tpu_like')
parser.add_argument('--accelerator', metavar='path', required=True, help='module path to the accelerator, e.g. inputs.examples.hardware.TPU_like')
args = parser.parse_args()
# Configure the root logger for this run: messages at INFO level and above,
# each prefixed with a timestamp, the logger name, the emitting function, the
# source line number and the level.
import logging as _logging

_logging_format = '%(asctime)s - %(name)s.%(funcName)s +%(lineno)s - %(levelname)s - %(message)s'
_logging_level = _logging.INFO
_logging.basicConfig(format=_logging_format, level=_logging_level)
# Build a short experiment identifier "<hardware>-<workload>" used to name
# the output files.
# The accelerator is a dotted module path, so its last component is its name.
hw_name = args.accelerator.split(".")[-1]
# The workload may be a dotted module path or a file path (e.g. an .onnx
# file). Split once on '/', '\' (Windows separators) and '.'. Without the
# backslash, a Windows-style path would leave its directories embedded in the
# name and therefore in the output file paths.
_wl_parts = re.split(r"[/\\.]", args.model)
wl_name = _wl_parts[-1]
# For an ONNX file, use the file stem rather than the 'onnx' extension.
# The length guard avoids an IndexError when the argument is just "onnx".
if wl_name == 'onnx' and len(_wl_parts) > 1:
    wl_name = _wl_parts[-2]
experiment_id = f"{hw_name}-{wl_name}"
pkl_name = f'{experiment_id}-saved_list_of_cmes'
# Stage classes handed to MainStage. They are executed in sequence, so the
# order of this list matters.
_stage_sequence = [
    WorkloadParserStage,
    AcceleratorParserStage,
    SimpleSaveStage,
    PickleSaveStage,
    SumStage,
    CompleteSaveStage,
    WorkloadStage,
    SpatialMappingGeneratorStage,
    MinimalLatencyStage,
    LomaStage,
    CostModelStage,
]

# The keyword arguments are the inputs required by the stages above: the
# parsed command-line paths, the output file names derived from the
# experiment id, and LOMA search settings. The '?' in the dump pattern is
# presumably substituted per layer by the saving stage — confirm in zigzag.
mainstage = MainStage(
    _stage_sequence,
    workload=args.model,
    accelerator=args.accelerator,
    mapping=args.mapping,
    pickle_filename=f"outputs/{pkl_name}.pickle",
    dump_filename_pattern=f"outputs/{experiment_id}-layer_?.json",
    loma_show_progress_bar=True,
    loma_lpf_limit=6,
)

# Run the whole pipeline.
mainstage.run()