Skip to content

Commit

Permalink
improving reloading of the agents
Browse files Browse the repository at this point in the history
  • Loading branch information
BDonnot committed Jul 6, 2020
1 parent eb19e42 commit 81b95c8
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
2 changes: 1 addition & 1 deletion l2rpn_baselines/DuelQSimple/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def train(env,
sizes = [800, 800, 800, 494, 494, 494] # sizes of each hidden layers
kwargs_archi = {'observation_size': observation_size,
'sizes': sizes,
'activs': ["relu" for _ in range(sizes)], # all relu activation function
'activs': ["relu" for _ in sizes], # all relu activation function
"list_attr_obs": li_attr_obs_X}
# select some part of the action
Expand Down
12 changes: 6 additions & 6 deletions l2rpn_baselines/SAC/SAC_NN.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
from tensorflow.keras.models import load_model, Sequential, Model
import tensorflow.keras.optimizers as tfko
from tensorflow.keras.layers import Activation, Dense
from tensorflow.keras.layers import Input, Concatenate

Expand Down Expand Up @@ -245,11 +244,12 @@ def load_network(self, path, name=None, ext="h5"):
We load all the models using the keras "load_model" function.
"""
path_model, path_target_model, path_modelQ, path_modelQ2, path_policy = self._get_path_model(path, name)
self.model_value = load_model('{}.{}'.format(path_model, ext))
self.model_value_target = load_model('{}.{}'.format(path_target_model, ext))
self.model_Q = load_model('{}.{}'.format(path_modelQ, ext))
self.model_Q2 = load_model('{}.{}'.format(path_modelQ2, ext))
self.model_policy = load_model('{}.{}'.format(path_policy, ext))
self.construct_q_network()
self.model_value.load_weights('{}.{}'.format(path_model, ext))
self.model_value_target.load_weights('{}.{}'.format(path_target_model, ext))
self.model_Q.load_weights('{}.{}'.format(path_modelQ, ext))
self.model_Q2.load_weights('{}.{}'.format(path_modelQ2, ext))
self.model_policy.load_weights('{}.{}'.format(path_policy, ext))
if self.verbose:
print("Succesfully loaded network.")

Expand Down
9 changes: 5 additions & 4 deletions l2rpn_baselines/utils/BaseDeepQ.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import tensorflow as tf
import tensorflow.keras.optimizers as tfko

from tensorflow.keras.models import load_model

from l2rpn_baselines.utils.TrainingParam import TrainingParam


Expand Down Expand Up @@ -204,11 +202,14 @@ def load_network(self, path, name=None, ext="h5"):
The file extension (by default h5)
"""
path_model, path_target_model = self.get_path_model(path, name)
self._model = load_model('{}.{}'.format(path_model, ext), custom_objects=self._custom_objects)
# fix for issue https://github.com/keras-team/keras/issues/7440
self.construct_q_network()

self._model.load_weights('{}.{}'.format(path_model, ext))

with warnings.catch_warnings():
warnings.filterwarnings("ignore")
self._target_model = load_model('{}.{}'.format(path_target_model, ext), custom_objects=self._custom_objects)
self._target_model.load_weights('{}.{}'.format(path_target_model, ext))
if self.verbose:
print("Succesfully loaded network.")

Expand Down

0 comments on commit 81b95c8

Please sign in to comment.