main.py
import numpy as np
import dill
from actor_critic import Agent
from utils import plot_learning_curve
from my_env import ENV

# Based on: https://github.com/philtabor/Youtube-Code-Repository/tree/master/ReinforcementLearning/PolicyGradient/actor_critic/tensorflow2

env = ENV()
agent = Agent(alpha=1e-5, n_actions=env.n_action)
n_games = 500


def learning(agent, env, n_games):
    filename = 'cartpole_demo.png'
    figure = 'plots/' + filename
    best_score = env.reward_range[0]
    score_history = []
    load_checkpoint = False

    if load_checkpoint:
        # Build the network first so the weights have defined shapes to load into.
        agent.actor_critic.build(input_shape=(None, 31))
        agent.load_models()

    for i in range(n_games):
        observation = env.reset()
        done = False
        score = 0
        time = 0
        while not done:
            action = agent.choose_action(observation)
            observation_, reward, done, info = env.step(action, time)
            score += reward
            if not load_checkpoint:
                agent.learn(observation, reward, observation_, done)
            observation = observation_
            time += 1
        score_history.append(score)
        avg_score = np.mean(score_history[-100:])

        if avg_score > best_score:
            best_score = avg_score
            # pickle cannot serialize this agent object; dill can.
            with open('best_agent.pkl', 'wb') as f:
                dill.dump(agent, f)
            if not load_checkpoint:
                agent.save_models()

        print(f'episode {i} score {score} avg_score {avg_score}')

    if not load_checkpoint:
        x = [i + 1 for i in range(n_games)]
        plot_learning_curve(x, score_history, figure)


if __name__ == '__main__':
    learning(agent, env, n_games)
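
Since the script dumps the best-scoring agent to best_agent.pkl with dill, a minimal sketch of reloading that snapshot for a single evaluation rollout might look like the following. It assumes the same ENV and Agent definitions are importable; the greedy evaluation loop is illustrative and not part of the original script.

import dill
from my_env import ENV

# Restore the full agent object saved during training; dill rebuilds it whole.
with open('best_agent.pkl', 'rb') as f:
    best_agent = dill.load(f)

# Run one episode without calling agent.learn(), so no weights are updated.
env = ENV()
observation = env.reset()
done = False
score = 0
time = 0
while not done:
    action = best_agent.choose_action(observation)
    observation, reward, done, info = env.step(action, time)
    score += reward
    time += 1
print(f'evaluation score: {score}')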