Skip to content

Commit

Permalink
updating tests after seeding changes
Browse files Browse the repository at this point in the history
  • Loading branch information
FilippoAiraldi committed Nov 21, 2023
1 parent cc5be42 commit b59e478
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 1,377 deletions.
4 changes: 2 additions & 2 deletions src/mpcrl/util/seeding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
]


MAX_SEED = np.iinfo(np.uint32).max # 2**32 - 1
MAX_SEED = np.iinfo(np.uint32).max + 1


def mk_seed(rng: np.random.Generator) -> int:
Expand All @@ -28,6 +28,6 @@ def mk_seed(rng: np.random.Generator) -> int:
Returns
-------
int
A random integer in the range [0, 2**32 - 1]
A random integer in the range [0, 2**32)
"""
return int(rng.integers(MAX_SEED))
Binary file modified tests/data_test_examples_win32.mat
Binary file not shown.
5 changes: 3 additions & 2 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from mpcrl import exploration as E
from mpcrl import schedulers as S
from mpcrl.util.seeding import mk_seed

OPTS = {
"expand": True,
Expand Down Expand Up @@ -377,9 +378,9 @@ def test_evaluate__performs_correct_calls(self):
)

np.testing.assert_allclose(returns, rewards.reshape(-1, episode_length).sum(1))
seeds = np.random.SeedSequence(seed).generate_state(episodes)
rng = np.random.default_rng(seed)
env.reset.assert_has_calls(
[call(seed=seeds[i], options=reset_options) for i in range(episodes)]
[call(seed=mk_seed(rng), options=reset_options) for i in range(episodes)]
)
for mcall, u1, u2 in zip(env.step.mock_calls, actions1, actions2):
self.assertEqual(len(mcall.args), 1)
Expand Down
Loading

0 comments on commit b59e478

Please sign in to comment.