# run_hyper.py -- forked from RUCAIBox/RecBole
# -*- coding: utf-8 -*-
# @Time : 2020/7/24 15:57
# @Author : Shanlei Mu
# @Email : [email protected]
# @File : run_hyper.py
# UPDATE:
# @Time : 2020/8/20 21:17, 2020/8/29, 2022/7/13, 2022/7/18
# @Author : Zihan Lin, Yupeng Hou, Gaowei Zhang, Lei Wang
import argparse
import os
import numpy as np
from recbole.trainer import HyperTuning
from recbole.quick_start import objective_function
import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler
import math
def hyperopt_tune(args):
    """Tune hyper-parameters with ``HyperTuning`` and print the best result.

    Args:
        args: Parsed command-line namespace. Reads ``config_files``
            (whitespace-separated fixed config files, may be ``None``),
            ``params_file``, ``display_file`` and ``output_file``.
    """
    # Split on any run of whitespace so that repeated spaces between file
    # names do not yield empty-string entries; an empty or whitespace-only
    # value is treated the same as "no fixed config files" (None).
    config_file_list = (args.config_files or "").split() or None

    # Set algo='exhaustive' to use exhaustive search; in that case max_evals
    # is set automatically. For any other algo, max_evals must be set manually.
    hp = HyperTuning(
        objective_function,
        algo="exhaustive",
        early_stop=10,
        max_evals=100,
        params_file=args.params_file,
        fixed_config_file_list=config_file_list,
        display_file=args.display_file,
    )
    hp.run()
    hp.export_result(output_file=args.output_file)

    print("best params: ", hp.best_params)
    print("best result: ")
    print(hp.params2result[hp.params2str(hp.best_params)])
def _load_ray_search_space(params_file):
    """Build a Ray Tune search space from a parameter-range file.

    Each useful line has the form ``<name> <type> <range...>``, with fields
    separated by whitespace. Lines with fewer than three fields are skipped.
    Everything after the type is concatenated (without whitespace) to form
    the range text.

    Args:
        params_file: Path of the parameter-range file to read.

    Returns:
        dict mapping each parameter name to the Ray Tune sampling object
        created for it.

    Raises:
        ValueError: If a line uses a type other than ``choice``, ``uniform``,
            ``quniform`` or ``loguniform``, or a numeric range does not have
            the expected number of comma-separated values.
    """
    search_space = {}
    with open(params_file, "r") as fp:
        for line in fp:
            # split() with no argument tolerates tabs and repeated spaces,
            # which would otherwise shift the fields and be reported as an
            # illegal (empty) parameter type.
            fields = line.split()
            if len(fields) < 3:
                continue
            para_name, para_type = fields[0], fields[1]
            para_value = "".join(fields[2:])

            if para_type == "choice":
                # NOTE(security): eval executes arbitrary Python taken from
                # the params file -- only use this with trusted files.
                search_space[para_name] = tune.choice(eval(para_value))
            elif para_type == "uniform":
                low, high = para_value.split(",")
                search_space[para_name] = tune.uniform(float(low), float(high))
            elif para_type == "quniform":
                low, high, q = para_value.split(",")
                search_space[para_name] = tune.quniform(
                    float(low), float(high), float(q)
                )
            elif para_type == "loguniform":
                # The bounds are exponentiated before being handed to Ray,
                # i.e. the file is read as giving them in natural-log space.
                low, high = para_value.split(",")
                search_space[para_name] = tune.loguniform(
                    math.exp(float(low)), math.exp(float(high))
                )
            else:
                raise ValueError("Illegal param type [{}]".format(para_type))
    return search_space


def ray_tune(args):
    """Tune hyper-parameters with Ray Tune and print the best trial.

    Args:
        args: Parsed command-line namespace. Reads ``config_files``
            (whitespace-separated fixed config files, may be ``None``),
            ``params_file`` (parameter-range file, required) and
            ``output_file`` (passed to ``tune.run`` as ``log_to_file``).

    Raises:
        ValueError: If ``params_file`` is missing, or the params file
            contains an unsupported parameter type.
    """
    if not args.params_file:
        # Fail before Ray is started, with a clear message, instead of the
        # TypeError that open(None) would raise later on.
        raise ValueError("--params_file is required when tuning with Ray")

    # All paths are anchored to the current working directory.
    # NOTE(review): presumably because trials may run from a different
    # working directory -- confirm.
    cwd = os.getcwd()
    config_names = (args.config_files or "").split()
    # An empty / whitespace-only value means "no fixed config files" (None).
    config_file_list = [os.path.join(cwd, name) for name in config_names] or None
    params_file = os.path.join(cwd, args.params_file)

    ray.init()
    tune.register_trainable("train_func", objective_function)
    config = _load_ray_search_space(params_file)

    # choose different schedulers to use different tuning optimization algorithms
    # For details, please refer to Ray's official website https://docs.ray.io
    scheduler = ASHAScheduler(
        metric="recall@10", mode="max", max_t=10, grace_period=1, reduction_factor=2
    )

    local_dir = "./ray_log"
    result = tune.run(
        tune.with_parameters(objective_function, config_file_list=config_file_list),
        config=config,
        num_samples=5,
        log_to_file=args.output_file,
        scheduler=scheduler,
        local_dir=local_dir,
        resources_per_trial={"gpu": 1},
    )

    best_trial = result.get_best_trial("recall@10", "max", "last")
    print("best params: ", best_trial.config)
    print("best result: ", best_trial.last_result)
if __name__ == "__main__":
    # Command-line entry point: register the options, then hand the parsed
    # namespace to the tuning routine selected by --tool.
    cli = argparse.ArgumentParser()
    option_specs = (
        # (flag, default, help text) -- every option is a plain string.
        ("--config_files", None, "fixed config files"),
        ("--params_file", None, "parameters file"),
        ("--output_file", "hyper_example.result", "output file"),
        ("--display_file", None, "visualization file"),
        ("--tool", "Hyperopt", "tuning tool"),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    # parse_known_args: options not registered above are returned separately
    # and discarded here rather than rejected.
    args, _ = cli.parse_known_args()

    tuners = {"Hyperopt": hyperopt_tune, "Ray": ray_tune}
    if args.tool not in tuners:
        raise ValueError(f"The tool [{args.tool}] should in ['Hyperopt', 'Ray']")
    tuners[args.tool](args)