diff --git a/src/python/env.py b/src/python/env.py index cd3ae580a..37f1c161e 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -270,7 +270,7 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] if self.continuous: action = tuple(action) if len(action) != 3: - raise ValueError("Actions must have 3-dimensions.") + raise error.Error("Actions must have 3-dimensions.") r, theta, fire = action reward += self.ale.actContinuous(r, theta, fire) diff --git a/tests/python/test_atari_env.py b/tests/python/test_atari_env.py index 1521cc01b..e50388bb4 100644 --- a/tests/python/test_atari_env.py +++ b/tests/python/test_atari_env.py @@ -329,6 +329,35 @@ def test_gym_action_space(tetris_env): assert tetris_env.action_space.n == 18 +@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) +def test_continuous_action_space(tetris_env): + assert isinstance(tetris_env.action_space, gymnasium.spaces.Box) + assert len(tetris_env.action_space.shape) == 1 + assert tetris_env.action_space.shape[0] == 3 + np.testing.assert_array_equal(tetris_env.action_space.low, np.array([0., -1., 0.])) + np.testing.assert_array_equal(tetris_env.action_space.high, np.array([1., 1., 1.])) + + +@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) +def test_continuous_action_sample(tetris_env): + tetris_env.reset(seed=0) + for _ in range(100): + tetris_env.step(tetris_env.action_space.sample()) + + +@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) +def test_continuous_step_with_correct_dimensions(tetris_env): + tetris_env.reset(seed=0) + tetris_env.step([0., -0.5, 0.5]) + + +@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) +def test_continuous_step_fails_with_wrong_dimensions(tetris_env): + tetris_env.reset(seed=0) + with pytest.raises(gymnasium.error.Error): + tetris_env.step([0., 1.]) + + def test_gym_reset_with_infos(tetris_env): pack = tetris_env.reset(seed=0)