For training on the single video; not sure if it worked, but it does not matter now, as the approach has changed.
1 parent eb0d4b0 · commit 758817d
Showing 1 changed file with 53 additions and 0 deletions.
@@ -0,0 +1,53 @@
import numpy as np
import gym
import SVE

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam

from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import SequentialMemory


ENV_NAME = 'SingleVideoEnv-v0'


# Get the environment and extract the number of actions.
env = gym.make(ENV_NAME)
np.random.seed(123)
env.seed(123)
nb_actions = env.action_space.n

# Next, we build a very simple model.
model = Sequential()
model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
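# (The leading (1,) in the input shape matches window_length=1 in the
# replay memory below: keras-rl feeds the network observation windows of
# shape (window_length,) + observation_space.shape.)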
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))
print(model.summary())

# Finally, we configure and compile our agent. You can use every built-in Keras optimizer and
# even the metrics!
memory = SequentialMemory(limit=50000, window_length=1)
policy = BoltzmannQPolicy()
dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=10,
               target_model_update=1e-2, policy=policy)
dqn.compile(Adam(lr=1e-3), metrics=['mae'])
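# In keras-rl, a target_model_update below 1 (as here) selects a soft
# target update: the target network's weights are blended toward the
# online network's with coefficient 0.01 after each training step, while
# a value of 1 or more would instead mean a hard copy every that many
# steps. BoltzmannQPolicy samples actions from a softmax over the
# predicted Q-values rather than acting greedily.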

# Okay, now it's time to learn something! Visualization is disabled here,
# since rendering slows down training quite a lot. You can always safely
# abort the training prematurely using Ctrl + C.
dqn.fit(env, nb_steps=25310, visualize=False, verbose=2)

# After training is done, we save the final weights.
dqn.save_weights('dqn_{}_weights.h5f'.format(ENV_NAME), overwrite=True)

# Finally, evaluate our algorithm for 5 episodes.
dqn.test(env, nb_episodes=5, visualize=False)
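A note on the custom environment: gym.make('SingleVideoEnv-v0') can only resolve that ID if it has already been registered with gym, which is presumably the side effect of the otherwise-unused import SVE at the top of the file. A minimal sketch of such a registration, where the module path and class name are assumptions rather than code from this repository:

from gym.envs.registration import register

# Hypothetical registration (not part of this commit): makes the ID
# resolvable via gym.make('SingleVideoEnv-v0') once the package is imported.
register(
    id='SingleVideoEnv-v0',
    entry_point='SVE.envs:SingleVideoEnv',  # assumed location of the env class
)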