updated tests and examples for new Evaluate wrapper
FilippoAiraldi committed Apr 15, 2024
1 parent f2eee15 commit 8258337
Showing 2 changed files with 64 additions and 26 deletions.
examples/q_learning_offpolicy.py (29 additions, 25 deletions)
@@ -22,7 +22,7 @@
 from mpcrl import Agent, LearnableParameter, LearnableParametersDict, LstdQLearningAgent
 from mpcrl.optim import NetwonMethod
 from mpcrl.util.control import dlqr
-from mpcrl.wrappers.agents import Log, RecordUpdates
+from mpcrl.wrappers.agents import Evaluate, Log, RecordUpdates
 from mpcrl.wrappers.envs import MonitorEpisodes
 
 # first, create classes for environment and mpc controller
@@ -182,48 +182,52 @@ def _generate_rollout(n):
 
 if __name__ == "__main__":
     # now, let's create the instances of such classes
+    seed = 69
     mpc = LinearMpc()
     learnable_pars = LearnableParametersDict[cs.SX](
         (
             LearnableParameter(name, val.shape, val, sym=mpc.parameters[name])
             for name, val in mpc.learnable_pars_init.items()
         )
     )
-    agent = Log(  # type: ignore[var-annotated]
-        RecordUpdates(
-            LstdQLearningAgent(
-                mpc=mpc,
-                learnable_parameters=learnable_pars,
-                discount_factor=mpc.discount_factor,
-                update_strategy=1,
-                optimizer=NetwonMethod(learning_rate=5e-2),
-                hessian_type="approx",
-                record_td_errors=True,
-                remove_bounds_on_initial_action=True,
-            )
-        ),
-        level=logging.DEBUG,
-        log_frequencies={"on_episode_end": 1},
+    eval_env = MonitorEpisodes(TimeLimit(LtiSystem(), 100))
+    agent = Evaluate(
+        Log(
+            RecordUpdates(
+                LstdQLearningAgent(
+                    mpc=mpc,
+                    learnable_parameters=learnable_pars,
+                    discount_factor=mpc.discount_factor,
+                    update_strategy=1,
+                    optimizer=NetwonMethod(learning_rate=5e-2),
+                    hessian_type="approx",
+                    record_td_errors=True,
+                    remove_bounds_on_initial_action=True,
+                )
+            ),
+            level=logging.DEBUG,
+            log_frequencies={"on_episode_end": 1},
+        ),
+        eval_env=eval_env,
+        hook="on_episode_end",
+        frequency=3,
+        n_eval_episodes=5,
+        eval_immediately=True,
+        seed=seed,
     )
 
     # before training, let's create a nominal non-learning agent which will be used to
     # generate expert rollout data. This data will then be used to train the off-policy
     # q-learning agent.
-    seed = 69
     env_factory = lambda: MonitorEpisodes(TimeLimit(LtiSystem(), 100))
     generate_rollout = get_rollout_generator(env_factory, seed)
 
     # finally, we can launch the training
-    n_rollouts = 100
-    eval_returns = agent.train_offpolicy(
-        episode_rollouts=(generate_rollout(n) for n in range(n_rollouts)),
-        seed=seed,
-        eval_frequency=10,
-        eval_env_factory=env_factory,
-        eval_kwargs={
-            "episodes": 5,  # every 10 rollouts, evaluate the agent on 5 episodes
-        },
+    n_rollouts = 10
+    agent.train_offpolicy(
+        episode_rollouts=(generate_rollout(n) for n in range(n_rollouts)), seed=seed
    )
+    eval_returns = np.asarray(agent.eval_returns)
 
     # plot the results
     import matplotlib.pyplot as plt
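Note on the example change above: evaluation is no longer configured through train_offpolicy keyword arguments (eval_frequency, eval_env_factory, eval_kwargs); instead, the agent is wrapped in the new Evaluate wrapper at construction time, and the evaluation returns are read back afterwards via agent.eval_returns. The snippet below is a minimal, illustrative stand-in for how such a hook-based wrapper can behave. It is not mpcrl's implementation: EvaluateSketch, DummyAgent, and the internals _count and _evaluate are invented names; only the constructor arguments mirror those visible in the diff.

# Illustrative stand-in for a hook-based evaluation wrapper (NOT mpcrl's code):
# it patches the chosen callback on the agent, counts callback firings, runs an
# evaluation every `frequency`-th firing, and stores the returns in eval_returns.
class EvaluateSketch:
    def __init__(self, agent, eval_env, hook, *, frequency, n_eval_episodes=1,
                 eval_immediately=False, seed=None):
        self.agent = agent
        self.eval_env = eval_env
        self.n_eval_episodes = n_eval_episodes
        self.seed = seed
        self.eval_returns = []
        self._count = 0
        original = getattr(agent, hook)  # keep the original callback

        def hooked(*args, **kwargs):
            self._count += 1
            if self._count % frequency == 0:
                self._evaluate()
            return original(*args, **kwargs)

        setattr(agent, hook, hooked)  # shadow the callback on the instance
        if eval_immediately:
            self._evaluate()

    def _evaluate(self):
        # hypothetical call signature; the real wrapper calls the wrapped
        # agent's `evaluate` method with its own argument layout
        self.eval_returns.append(
            self.agent.evaluate(self.eval_env, self.n_eval_episodes, self.seed)
        )


class DummyAgent:
    def on_episode_end(self, env, episode, rewards):
        pass

    def evaluate(self, env, episodes, seed=None):
        return 0.0


agent = DummyAgent()
wrapped = EvaluateSketch(agent, None, "on_episode_end",
                         frequency=3, n_eval_episodes=5, eval_immediately=True)
for i in range(9):
    agent.on_episode_end(None, i, 0.0)
print(len(wrapped.eval_returns))  # 4: one immediately, then after episodes 3, 6, 9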
tests/test_wrappers.py (35 additions, 1 deletion)
@@ -2,7 +2,8 @@
 import os
 import unittest
 from functools import lru_cache
-from itertools import combinations
+from itertools import combinations, product
+from random import random
 from typing import Any
 from unittest.mock import Mock, call
 
@@ -12,6 +13,7 @@
 from csnlp import Nlp, scaling
 from csnlp.multistart import StackedMultistartNlp
 from csnlp.wrappers import Mpc, NlpScaling
+from parameterized import parameterized
 
 from mpcrl import (
     Agent,
@@ -353,6 +355,38 @@ def side_effect():
     np.testing.assert_equal(pars_actual, pars)
 
 
+class TestEvaluate(unittest.TestCase):
+    @parameterized.expand(product((False, True), (False, True)))
+    def test_evaluate__evaluates_with_correct_frequency(
+        self, eval_immediately: bool, fix_seed: bool
+    ):
+        frequency = 10
+        repeats = 2
+        returns = [object() for _ in range(repeats + eval_immediately)]
+        returns_iter = iter(returns)
+        agent = mk_agent()
+        agent.evaluate = Mock(side_effect=lambda *_, **__: next(returns_iter))
+        env = SimpleEnv()
+        wrapped = wrappers_agents.Evaluate(
+            agent,
+            env,
+            "on_episode_end",
+            frequency=frequency,
+            eval_immediately=eval_immediately,
+            fix_seed=fix_seed,
+        )
+
+        n_calls = frequency * repeats
+        n_calls += int(frequency / 2)  # adds some spurious calls
+        [agent.on_episode_end(env, i, random()) for i in range(n_calls)]
+
+        self.assertEqual(agent.evaluate.call_count, repeats + eval_immediately)
+        self.assertListEqual(wrapped.eval_returns, returns)
+        if fix_seed:
+            seeds = (call.args[3] for call in agent.evaluate.call_args_list)
+            self.assertEqual(len(set(seeds)), 1)
+
+
 class TestMonitorEpisodesAndInfos(unittest.TestCase):
     def test__compact_dicts(self):
         act = compact_dicts(
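For reference, the arithmetic behind the new test's expectations: 25 hook calls at frequency 10 should trigger evaluations at calls 10 and 20 only, with the remaining 5 "spurious" calls triggering nothing, plus one extra up-front evaluation when eval_immediately is set. The standalone snippet below (not part of the commit) simply re-derives those expected counts.

# Re-derive the expected number of evaluate() calls used by the test above
# (standalone sanity check, not part of this commit).
frequency, repeats = 10, 2
n_calls = frequency * repeats + int(frequency / 2)  # 25 hook calls, 5 "spurious"
for eval_immediately in (False, True):
    expected = n_calls // frequency + int(eval_immediately)
    assert expected == repeats + int(eval_immediately)
print("expected evaluation counts verified:", repeats, "or", repeats + 1)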
