reworked seeding
FilippoAiraldi committed Nov 21, 2023
1 parent e604866 commit cc5be42
Showing 5 changed files with 63 additions and 35 deletions.
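
In short, the commit centralizes RNG seeding: the ad-hoc `Union[None, int, Sequence[int], np.random.SeedSequence]` annotations are replaced by a single `RngType` alias, and rather than pre-generating one seed per episode with `np.random.SeedSequence(seed).generate_state(episodes)`, `evaluate` and `train` now build one `np.random.Generator` up front and draw each episode's seed from it via the new `mk_seed` helper (both defined in the new `src/mpcrl/util/seeding.py`).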
19 changes: 9 additions & 10 deletions src/mpcrl/agents/agent.py
@@ -1,4 +1,4 @@
-from collections.abc import Collection, Iterable, Iterator, Sequence
+from collections.abc import Collection, Iterable, Iterator
 from typing import Any, Generic, Literal, Optional, TypeVar, Union
 
 import casadi as cs
@@ -14,6 +14,7 @@
 from ..core.callbacks import AgentCallbackMixin
 from ..core.exploration import ExplorationStrategy, NoExploration
 from ..util.named import Named
+from ..util.seeding import RngType, mk_seed
 
 SymType = TypeVar("SymType", cs.SX, cs.MX)
 ActType: TypeAlias = Union[npt.ArrayLike, dict[str, npt.ArrayLike]]
@@ -137,9 +138,7 @@ def exploration(self) -> ExplorationStrategy:
         """Gets the exploration strategy used within this agent."""
         return self._exploration
 
-    def reset(
-        self, seed: Union[None, int, Sequence[int], np.random.SeedSequence] = None
-    ) -> None:
+    def reset(self, seed: RngType = None) -> None:
         """Resets the agent's internal variables and exploration's RNG."""
         self._last_solution = None
         self._last_action = None
@@ -323,7 +322,7 @@ def evaluate(
         env: Env[ObsType, ActType],
         episodes: int,
         deterministic: bool = True,
-        seed: Union[None, int, Sequence[int]] = None,
+        seed: RngType = None,
         raises: bool = True,
         env_reset_options: Optional[dict[str, Any]] = None,
     ) -> npt.NDArray[np.floating]:
@@ -341,7 +340,7 @@ def evaluate(
             Number of evaluation episodes.
         deterministic : bool, optional
             Whether the agent should act deterministically; by default, `True`.
-        seed : None, int or sequence of ints, optional
+        seed : None, int, array_like[ints], SeedSequence, BitGenerator, Generator
             Agent's and each env's RNG seed.
         raises : bool, optional
             If `True`, when any of the MPC solver runs fails, or when an update fails,
@@ -359,13 +358,13 @@
         ------
         Raises if the MPC optimization solver fails and `warns_on_exception=False`.
         """
+        rng = np.random.default_rng(seed)
+        self.reset(rng)
         returns = np.zeros(episodes)
         self.on_validation_start(env)
-        seeds = map(int, np.random.SeedSequence(seed).generate_state(episodes))
 
-        for episode, current_seed in zip(range(episodes), seeds):
-            self.reset(current_seed)
-            state, _ = env.reset(seed=current_seed, options=env_reset_options)
+        for episode in range(episodes):
+            state, _ = env.reset(seed=mk_seed(rng), options=env_reset_options)
             truncated, terminated, timestep = False, False, 0
             self.on_episode_start(env, episode, state)
 
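A minimal sketch of the seeding scheme this hunk replaces versus the one it introduces (plain NumPy; `mk_seed` is re-declared here only so the snippet runs standalone):

import numpy as np

MAX_SEED = np.iinfo(np.uint32).max

def mk_seed(rng: np.random.Generator) -> int:
    # mirrors the new mpcrl.util.seeding.mk_seed helper
    return int(rng.integers(MAX_SEED))

episodes, seed = 3, 42

# before: every episode seed pre-generated from one SeedSequence
old_seeds = [int(s) for s in np.random.SeedSequence(seed).generate_state(episodes)]

# after: one Generator, built once, queried once per episode
rng = np.random.default_rng(seed)
new_seeds = [mk_seed(rng) for _ in range(episodes)]

# the new scheme is just as reproducible for a fixed seed
rng2 = np.random.default_rng(seed)
assert new_seeds == [mk_seed(rng2) for _ in range(episodes)]

Note also that `reset` now receives the live `Generator` once, before the loop, instead of a fresh integer seed at every episode, so the agent's exploration RNG streams across episodes rather than being re-seeded per episode.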
20 changes: 9 additions & 11 deletions src/mpcrl/agents/learning_agent.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from collections.abc import Collection, Sequence
+from collections.abc import Collection
 from typing import Any, Callable, Generic, Optional, TypeVar, Union
 
 import numpy as np
@@ -11,6 +11,7 @@
 from ..core.exploration import ExplorationStrategy
 from ..core.parameters import LearnableParametersDict
 from ..core.update import UpdateStrategy
+from ..util.seeding import RngType, mk_seed
 from .agent import ActType, Agent, ObsType, SymType, _update_dicts
 
 ExpType = TypeVar("ExpType")
@@ -91,9 +92,7 @@ def learnable_parameters(self) -> LearnableParametersDict[SymType]:
         """Gets the parameters of the MPC that can be learnt by the agent."""
         return self._learnable_pars
 
-    def reset(
-        self, seed: Union[None, int, Sequence[int], np.random.SeedSequence] = None
-    ) -> None:
+    def reset(self, seed: RngType = None) -> None:
         """Resets the agent's internal variables, exploration's and experience's RNGs."""
         super().reset(seed)
         self.experience.reset(seed)
@@ -116,7 +115,7 @@ def train(
         self,
         env: Env[ObsType, ActType],
         episodes: int,
-        seed: Union[None, int, Sequence[int]] = None,
+        seed: RngType = None,
         raises: bool = True,
         env_reset_options: Optional[dict[str, Any]] = None,
     ) -> npt.NDArray[np.floating]:
@@ -128,7 +127,7 @@
             A gym environment in which to train the agent.
         episodes : int
             Number of training episodes.
-        seed : None, int or sequence of ints, optional
+        seed : None, int, array_like[ints], SeedSequence, BitGenerator, Generator
             Agent's and each env's RNG seed.
         raises : bool, optional
             If `True`, when any of the MPC solver runs fails, or when an update fails,
@@ -150,15 +149,14 @@
         UpdateError or UpdateWarning
             Raises the error or the warning (depending on `raises`) if the update fails.
         """
+        rng = np.random.default_rng(seed)
+        self.reset(rng)
         self._updates_enabled = True
         self._raises = raises
         returns = np.zeros(episodes, float)
         self.on_training_start(env)
-        seeds = map(int, np.random.SeedSequence(seed).generate_state(episodes))
-
-        for episode, current_seed in zip(range(episodes), seeds):
-            self.reset(current_seed)
-            state, _ = env.reset(seed=current_seed, options=env_reset_options)
+        for episode in range(episodes):
+            state, _ = env.reset(seed=mk_seed(rng), options=env_reset_options)
             self.on_episode_start(env, episode, state)
             r = self.train_one_episode(env, episode, state, raises)
             self.on_episode_end(env, episode, r)
12 changes: 6 additions & 6 deletions src/mpcrl/core/experience.py
@@ -1,9 +1,11 @@
-from collections.abc import Iterable, Iterator, Sequence
+from collections.abc import Iterable, Iterator
 from itertools import chain
 from typing import Deque, Optional, TypeVar, Union
 
 import numpy as np
 
+from ..util.seeding import RngType
+
 ExpType = TypeVar("ExpType")
 
 
@@ -18,7 +20,7 @@ def __init__(
         maxlen: Optional[int] = None,
         sample_size: Union[int, float] = 1,
         include_latest: Union[int, float] = 0,
-        seed: Union[None, int, Sequence[int], np.random.SeedSequence] = None,
+        seed: RngType = None,
     ) -> None:
         """Instantiate the container for experience replay memory.
@@ -35,17 +37,15 @@
         include_latest : int or float, optional
             Size (or percentage of `sample_size`) dedicated to including the latest
             experience transitions. By default, 0, i.e., no last item is included.
-        seed : None, int, sequence of ints or SeedSequence, optional
+        seed : None, int, array_like[ints], SeedSequence, BitGenerator, Generator
             Seed for the random number generator. By default, `None`.
         """
         super().__init__(iterable, maxlen=maxlen)
         self.sample_size = sample_size
         self.include_latest = include_latest
         self.reset(seed)
 
-    def reset(
-        self, seed: Union[None, int, Sequence[int], np.random.SeedSequence] = None
-    ) -> None:
+    def reset(self, seed: RngType = None) -> None:
         """Resets the sampling RNG."""
         self.np_random = np.random.default_rng(seed)
 
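With the widened `RngType`, the replay memory accepts any NumPy-style seed; a usage sketch (assuming the container shown here is the package's `ExperienceReplay` class and that its first argument takes an initial iterable):

import numpy as np
from mpcrl.core.experience import ExperienceReplay  # class name assumed

mem = ExperienceReplay((), maxlen=100, sample_size=0.5, seed=42)  # int still works
mem.reset(np.random.default_rng(42))  # ...and so does a Generator now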
14 changes: 6 additions & 8 deletions src/mpcrl/core/exploration.py
@@ -1,10 +1,10 @@
 from abc import ABC, abstractmethod
-from collections.abc import Sequence
 from typing import Any, Literal, Optional, Union
 
 import numpy as np
 import numpy.typing as npt
 
+from ..util.seeding import RngType
 from .schedulers import NoScheduling, Scheduler
 
 
@@ -67,7 +67,7 @@ def __init__(
         self,
         strength: Union[Scheduler[npt.NDArray[np.floating]], npt.NDArray[np.floating]],
         hook: Literal["on_update", "on_episode_end", "on_timestep_end"] = "on_update",
-        seed: Union[None, int, Sequence[int], np.random.SeedSequence] = None,
+        seed: RngType = None,
     ) -> None:
         """Initializes the greedy exploration strategy.
@@ -88,7 +88,7 @@
             - `on_timestep_end` steps the exploration after each env's timestep.
             By default, 'on_update' is selected.
-        seed : None, int, sequence of ints or SeedSequence, optional
+        seed : None, int, array_like[ints], SeedSequence, BitGenerator, Generator
             Number to seed the RNG engine used for randomizing the exploration. By
             default, `None`.
         """
@@ -112,9 +112,7 @@ def strength(self) -> npt.NDArray[np.floating]:
         """Gets the current strength of the exploration strategy."""
         return self.strength_scheduler.value
 
-    def reset(
-        self, seed: Union[None, int, Sequence[int], np.random.SeedSequence] = None
-    ) -> None:
+    def reset(self, seed: RngType = None) -> None:
         """Resets the exploration RNG."""
         self.np_random = np.random.default_rng(seed)
 
@@ -166,7 +164,7 @@ def __init__(
         epsilon: Union[Scheduler[float], float],
         strength: Union[Scheduler[npt.NDArray[np.floating]], npt.NDArray[np.floating]],
         hook: Literal["on_update", "on_episode_end", "on_timestep_end"] = "on_update",
-        seed: Union[None, int, Sequence[int], np.random.SeedSequence] = None,
+        seed: RngType = None,
     ) -> None:
         """Initializes the epsilon-greedy exploration strategy.
@@ -189,7 +187,7 @@
             - `on_timestep_end` steps the exploration after each env's timestep.
             By default, 'on_update' is selected.
-        seed : None, int, sequence of ints or SeedSequence, optional
+        seed : None, int, array_like[ints], SeedSequence, BitGenerator, Generator
             Number to seed the RNG engine used for randomizing the exploration. By
             default, `None`.
         """
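Exploration strategies gain the same flexibility; a sketch assuming the second class shown above is named `EpsilonGreedyExploration` (only its `__init__` appears in the diff):

import numpy as np
from mpcrl.core.exploration import EpsilonGreedyExploration  # name assumed

expl = EpsilonGreedyExploration(
    epsilon=0.7,  # fixed value instead of a Scheduler
    strength=np.array([0.5]),
    seed=np.random.SeedSequence(1),  # any RngType value is accepted
)
expl.reset(seed=np.random.default_rng(2))  # re-seeding is equally permissive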
33 changes: 33 additions & 0 deletions src/mpcrl/util/seeding.py
@@ -0,0 +1,33 @@
+from collections.abc import Sequence
+from typing import Union
+
+import numpy as np
+from typing_extensions import TypeAlias
+
+RngType: TypeAlias = Union[
+    None,
+    int,
+    Sequence[int],
+    np.random.SeedSequence,
+    np.random.BitGenerator,
+    np.random.Generator,
+]
+
+
+MAX_SEED = np.iinfo(np.uint32).max  # 2**32 - 1
+
+
+def mk_seed(rng: np.random.Generator) -> int:
+    """Generates a random seed.
+
+    Parameters
+    ----------
+    rng : np.random.Generator
+        Random number generator used to draw the seed.
+
+    Returns
+    -------
+    int
+        A random integer in the range [0, 2**32 - 1).
+    """
+    return int(rng.integers(MAX_SEED))
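
A quick usage sketch for the new helper:

import numpy as np
from mpcrl.util.seeding import MAX_SEED, mk_seed

rng = np.random.default_rng(2023)
seeds = [mk_seed(rng) for _ in range(4)]  # e.g. one seed per episode
assert all(0 <= s < MAX_SEED for s in seeds)  # integers() excludes the endpoint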
