-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
86 lines (73 loc) · 2.54 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# @Filename: utils.py
# @Author: Ashutosh Tiwari
# @Email: [email protected]
# @Time: 4/16/22 8:50 PM
import os
from constants import *
from stable_baselines3.common.evaluation import evaluate_policy
import optuna
import time
from enum import Enum
from matplotlib import pyplot as plt
class GameState(Enum):
    """Enumeration of selectable game states, one member per character name.

    Each member's value is the matching ``STATE_*`` constant brought into scope
    by ``from constants import *``. Presumably these identify saved emulator
    states to start the environment from -- TODO confirm in constants.py.
    The names appear to be Street Fighter II fighters.
    """
    # NOTE(review): DAHLISM / EHDONA look like misspellings of Dhalsim / E. Honda;
    # renaming the members would break callers, so they are left as-is.
    GUILE = STATE_GUILE
    ZANGIEF = STATE_ZANGIEF
    DAHLISM = STATE_DAHLISM
    EHDONA = STATE_EHDONA
    CHUNLI = STATE_CHUNLI
    BLANKA = STATE_BLANKA
    KEN = STATE_KEN
    RYU = STATE_RYU
def record_model_playing(env, model, render=False, n_games=1):
    """Play ``model`` in ``env`` and return the total reward it collects.

    Runs ``n_games`` episodes. The environment is reset at the start of each
    episode and stepped until it reports ``done``. Uses the 4-tuple
    ``(obs, reward, done, info)`` ``step`` API and a ``predict`` that returns
    ``(action, state)``. The environment is always closed, even if an episode
    raises.

    :param env: (Gym Environment) environment to play in; closed before returning.
    :param model: object exposing ``predict(obs) -> (action, state)``.
    :param render: (bool) if True, render every step and pause 10 ms so the
        playback is watchable.
    :param n_games: (int) number of episodes to play (default 1: a single episode).
    :return: total reward summed over every step of every episode.
    """
    iteration = 0
    total_reward = 0
    try:
        for _episode in range(n_games):
            # Each episode starts from a fresh reset; `done` is per-episode.
            obs = env.reset()
            done = False
            while not done:
                iteration += 1
                if render:
                    env.render()
                    time.sleep(0.01)
                action, _state = model.predict(obs)
                obs, reward, done, _info = env.step(action)
                total_reward += reward
        print("iterations: ", iteration)
        print("total reward: ", total_reward)
    finally:
        # Release the environment's resources even when predict/step raises.
        env.close()
    return total_reward
def load_study(study_name, path):
    """Load an existing Optuna study from storage.

    :param study_name: (str) name the study was created under.
    :param path: (str | os.PathLike) either a full storage URL (anything
        containing ``"://"``, e.g. ``"sqlite:///studies.db"``) or a plain
        filesystem path to a SQLite database file, which is converted to a
        ``sqlite:///`` URL.
    :return: the loaded ``optuna.Study``.
    """
    # Keyword arguments: positional use of load_study is deprecated in Optuna 3.x.
    return optuna.load_study(study_name=study_name, storage=_sqlite_storage_url(path))


def _sqlite_storage_url(path):
    """Return ``path`` as a storage URL, prefixing bare file paths with ``sqlite:///``."""
    path = os.fspath(path)
    if "://" in path:
        # Already a URL (sqlite, postgresql, mysql, ...): pass through untouched.
        return path
    # Plain concatenation, not os.path.join: join() discards the prefix when
    # `path` is absolute. 'sqlite:///' + '/abs/x.db' yields the four-slash
    # absolute form SQLAlchemy expects.
    return "sqlite:///" + path
def evaluate_model_policy(env, model, n_eval_episodes=5):
    """
    Compute the mean episodic reward of a policy.

    Thin wrapper around Stable-Baselines3's ``evaluate_policy`` that returns only
    the first element of its result, the mean reward.

    :param env: (Gym Environment) environment in which the policy is run
    :param model: (BaseRLModel object) policy to evaluate; its type depends on the environment
    :param n_eval_episodes: (int) how many episodes to average over
    :return: (float) mean reward across the `n_eval_episodes` episodes
    """
    mean_reward, *_ = evaluate_policy(model, env, n_eval_episodes=n_eval_episodes)
    return mean_reward
def plot_study(study, path=None):
    """Render Optuna's matplotlib diagnostic plots for ``study``.

    Each plot (parallel coordinate, contour, slice, parameter importances, EDF,
    optimization history) is either shown interactively or written to
    ``<path>/<plot function name>.png``. Plotting is best-effort: a plot that
    raises is reported with ``print`` and skipped, and the remaining plots still
    run. Every figure a plot creates is closed afterwards, so repeated calls do
    not accumulate open matplotlib figures.

    :param study: (optuna.Study) the study to visualise.
    :param path: (str | None) directory to save PNGs into; if None, ``plt.show()``
        is called for each plot instead.
    """
    plots = [optuna.visualization.matplotlib.plot_parallel_coordinate,
             optuna.visualization.matplotlib.plot_contour,
             optuna.visualization.matplotlib.plot_slice,
             optuna.visualization.matplotlib.plot_param_importances,
             optuna.visualization.matplotlib.plot_edf,
             optuna.visualization.matplotlib.plot_optimization_history]
    for plot in plots:
        # Snapshot existing figures so only the ones this plot creates are closed
        # (leaves any figures the caller already had open alone).
        existing = set(plt.get_fignums())
        try:
            plot(study)
            if path is None:
                plt.show()
            else:
                p = os.path.join(path, plot.__name__ + '.png')
                print("writing fig at ", str(p))
                # Saves the current figure, i.e. the one `plot` just drew.
                plt.savefig(p)
        except Exception as e:
            # Deliberate best-effort: presumably some plots are unsupported for
            # certain studies (e.g. too few trials/params); report and continue.
            print("Error in plot: ", e)
        finally:
            for num in set(plt.get_fignums()) - existing:
                plt.close(num)