Skip to content

Commit

Permalink
Add unit tests for continuous action spaces.
Browse files Browse the repository at this point in the history
  • Loading branch information
psc-g committed Jul 15, 2024
1 parent f857a4e commit f9c8560
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/python/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride]
if self.continuous:
action = tuple(action)
if len(action) != 3:
raise ValueError("Actions must have 3-dimensions.")
raise error.Error("Actions must have 3-dimensions.")

r, theta, fire = action
reward += self.ale.actContinuous(r, theta, fire)
Expand Down
29 changes: 29 additions & 0 deletions tests/python/test_atari_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,35 @@ def test_gym_action_space(tetris_env):
assert tetris_env.action_space.n == 18


@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True)
def test_continuous_action_space(tetris_env):
assert isinstance(tetris_env.action_space, gymnasium.spaces.Box)
assert len(tetris_env.action_space.shape) == 1
assert tetris_env.action_space.shape[0] == 3
np.testing.assert_array_equal(tetris_env.action_space.low, np.array([0., -1., 0.]))
np.testing.assert_array_equal(tetris_env.action_space.high, np.array([1., 1., 1.]))


@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True)
def test_continuous_action_sample(tetris_env):
tetris_env.reset(seed=0)
for _ in range(100):
tetris_env.step(tetris_env.action_space.sample())


@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True)
def test_continuous_step_with_correct_dimensions(tetris_env):
tetris_env.reset(seed=0)
tetris_env.step([0., -0.5, 0.5])


@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True)
def test_continuous_step_fails_with_wrong_dimensions(tetris_env):
tetris_env.reset(seed=0)
with pytest.raises(gymnasium.error.Error):
tetris_env.step([0., 1.])


def test_gym_reset_with_infos(tetris_env):
pack = tetris_env.reset(seed=0)

Expand Down

0 comments on commit f9c8560

Please sign in to comment.