From a8b723c1f982a863f83353ee3f760c612a7d70e4 Mon Sep 17 00:00:00 2001
From: psc-g
Date: Fri, 31 May 2024 13:57:09 -0400
Subject: [PATCH 01/15] Changes to support continuous actions

---
 src/ale_interface.cpp                  |  10 +++
 src/ale_interface.hpp                  |   7 ++
 src/environment/ale_state.cpp          |  90 +++++++++++++++++++++
 src/environment/ale_state.hpp          |  14 ++++
 src/environment/stella_environment.cpp | 108 +++++++++++++++++++++++++
 src/environment/stella_environment.hpp |  28 +++++++
 src/python/__init__.pyi                |   2 +
 src/python/ale_python_interface.hpp    |   1 +
 src/python/env.py                      |  70 ++++++++++++++--
 9 files changed, 323 insertions(+), 7 deletions(-)

diff --git a/src/ale_interface.cpp b/src/ale_interface.cpp
index 83a146ea2..c6a79b3f5 100644
--- a/src/ale_interface.cpp
+++ b/src/ale_interface.cpp
@@ -263,6 +263,16 @@ reward_t ALEInterface::act(Action action) {
   return environment->act(action, PLAYER_B_NOOP);
 }
 
+// Applies a continuous action to the game and returns the reward. It is the
+// user's responsibility to check if the game has ended and reset
+// when necessary - this method will keep pressing buttons on the
+// game over screen.
+reward_t ALEInterface::actContinuous(float r, float theta, float fire,
+                                     float continuous_action_threshold) {
+  return environment->actContinuous(r, theta, fire, 0.0, 0.0, 0.0,
+                                    continuous_action_threshold);
+}
+
 // Returns the vector of modes available for the current game.
 // This should be called only after the rom is loaded.
 ModeVect ALEInterface::getAvailableModes() const {
diff --git a/src/ale_interface.hpp b/src/ale_interface.hpp
index 509f094d7..68ae9b164 100644
--- a/src/ale_interface.hpp
+++ b/src/ale_interface.hpp
@@ -87,6 +87,13 @@ class ALEInterface {
   // when necessary - this method will keep pressing buttons on the
   // game over screen.
   reward_t act(Action action);
+
+  // Applies a continuous action to the game and returns the reward. It is the
+  // user's responsibility to check if the game has ended and reset
+  // when necessary - this method will keep pressing buttons on the
+  // game over screen.
+  reward_t actContinuous(float r, float theta, float fire,
+                         float continuous_action_threshold = 0.5);
 
   // Indicates if the game has ended.
   bool game_over(bool with_truncation = true) const;
diff --git a/src/environment/ale_state.cpp b/src/environment/ale_state.cpp
index 8f4eafe82..fd8749be5 100644
--- a/src/environment/ale_state.cpp
+++ b/src/environment/ale_state.cpp
@@ -13,6 +13,7 @@
 #include "environment/ale_state.hpp"
 #include
+#include <cmath>
 #include
 #include
 #include
@@ -287,6 +288,47 @@ void ALEState::applyActionPaddles(Event* event, int player_a_action,
   }
 }
 
+void ALEState::applyActionPaddlesContinuous(
+    Event* event,
+    float player_a_r, float player_a_theta, float player_a_fire,
+    float player_b_r, float player_b_theta, float player_b_fire,
+    float continuous_action_threshold) {
+  // Reset keys
+  resetKeys(event);
+
+  // Convert polar coordinates to x/y position.
+  float a_x = player_a_r * cos(player_a_theta);
+  float a_y = player_a_r * sin(player_a_theta);
+  float b_x = player_b_r * cos(player_b_theta);
+  float b_y = player_b_r * sin(player_b_theta);
+
+  // First compute whether we should increase or decrease the paddle position
+  // (for both left and right players)
+  int delta_a = 0;
+  if (a_x > continuous_action_threshold) { // Right action.
+    delta_a = -PADDLE_DELTA;
+  } else if (a_x < continuous_action_threshold) { // Left action.
+    delta_a = PADDLE_DELTA;
+  }
+  int delta_b = 0;
+  if (b_x > continuous_action_threshold) { // Right action.
+    delta_b = -PADDLE_DELTA;
+  } else if (b_x < continuous_action_threshold) { // Left action.
+    delta_b = PADDLE_DELTA;
+  }
+
+  // Now update the paddle positions
+  updatePaddlePositions(event, delta_a, delta_b);
+
+  // Now add the fire event
+  if (player_a_fire > continuous_action_threshold) {
+    event->set(Event::PaddleZeroFire, 1);
+  }
+  if (player_b_fire > continuous_action_threshold) {
+    event->set(Event::PaddleOneFire, 1);
+  }
+}
+
 void ALEState::pressSelect(Event* event) {
   resetKeys(event);
   event->set(Event::ConsoleSelect, 1);
@@ -498,6 +540,54 @@ void ALEState::setActionJoysticks(Event* event, int player_a_action,
   }
 }
 
+
+void ALEState::setActionJoysticksContinuous(
+    Event* event,
+    float player_a_r, float player_a_theta, float player_a_fire,
+    float player_b_r, float player_b_theta, float player_b_fire,
+    float continuous_action_threshold) {
+  // Reset keys
+  resetKeys(event);
+
+  // Convert polar coordinates to x/y position.
+  float a_x = player_a_r * cos(player_a_theta);
+  float a_y = player_a_r * sin(player_a_theta);
+  float b_x = player_b_r * cos(player_b_theta);
+  float b_y = player_b_r * sin(player_b_theta);
+
+  // Go through all possible events and add them if joystick position is there.
+  if (a_x < -continuous_action_threshold) {
+    event->set(Event::JoystickZeroLeft, 1);
+  }
+  if (a_x > continuous_action_threshold) {
+    event->set(Event::JoystickZeroRight, 1);
+  }
+  if (a_y < -continuous_action_threshold) {
+    event->set(Event::JoystickZeroDown, 1);
+  }
+  if (a_y > continuous_action_threshold) {
+    event->set(Event::JoystickZeroUp, 1);
+  }
+  if (player_a_fire > continuous_action_threshold) {
+    event->set(Event::JoystickZeroFire, 1);
+  }
+  if (b_x < -continuous_action_threshold) {
+    event->set(Event::JoystickOneLeft, 1);
+  }
+  if (b_x > continuous_action_threshold) {
+    event->set(Event::JoystickOneRight, 1);
+  }
+  if (b_y < -continuous_action_threshold) {
+    event->set(Event::JoystickOneDown, 1);
+  }
+  if (b_y > continuous_action_threshold) {
+    event->set(Event::JoystickOneUp, 1);
+  }
+  if (player_b_fire > continuous_action_threshold) {
+    event->set(Event::JoystickOneFire, 1);
+  }
+}
+
 /* ***************************************************************************
  Function resetKeys
  Unpresses all control-relevant keys
diff --git a/src/environment/ale_state.hpp b/src/environment/ale_state.hpp
index 0f1f5641f..6aad84683 100644
--- a/src/environment/ale_state.hpp
+++ b/src/environment/ale_state.hpp
@@ -64,9 +64,23 @@ class ALEState {
    * resistances. */
   void applyActionPaddles(stella::Event* event_obj, int player_a_action,
                           int player_b_action);
+  /** Applies continuous paddle actions. This actually modifies the game state
+   * by updating the paddle resistances. */
+  void applyActionPaddlesContinuous(
+      stella::Event* event_obj,
+      float player_a_r, float player_a_theta, float player_a_fire,
+      float player_b_r, float player_b_theta, float player_b_fire,
+      float continuous_action_threshold = 0.5);
   /** Sets the joystick events. No effect until the emulator is run forward. */
   void setActionJoysticks(stella::Event* event_obj, int player_a_action,
                           int player_b_action);
+  /** Sets the joystick events for continuous actions. No effect until the
+   * emulator is run forward.
*/ + void setActionJoysticksContinuous( + stella::Event* event_obj, + float player_a_r, float player_a_theta, float player_a_fire, + float player_b_r, float player_b_theta, float player_b_fire, + float continuous_action_threshold = 0.5); void incrementFrame(int steps = 1); diff --git a/src/environment/stella_environment.cpp b/src/environment/stella_environment.cpp index a54443826..6ee48b124 100644 --- a/src/environment/stella_environment.cpp +++ b/src/environment/stella_environment.cpp @@ -187,6 +187,51 @@ reward_t StellaEnvironment::act(Action player_a_action, return std::clamp(sum_rewards, m_reward_min, m_reward_max); } +reward_t StellaEnvironment::actContinuous( + float player_a_r, float player_a_theta, float player_a_fire, + float player_b_r, float player_b_theta, float player_b_fire, + float continuous_action_threshold) { + // Total reward received as we repeat the action + reward_t sum_rewards = 0; + + Random& rng = getEnvironmentRNG(); + + // Apply the same action for a given number of times... note that act() will refuse to emulate + // past the terminal state + for (size_t i = 0; i < m_frame_skip; i++) { + // Stochastically drop actions, according to m_repeat_action_probability + if (rng.nextDouble() >= m_repeat_action_probability) { + m_player_a_r = player_a_r; + m_player_a_theta = player_a_theta; + m_player_a_fire = player_a_fire; + } + // @todo Possibly optimize by avoiding call to rand() when player B is "off" ? + if (rng.nextDouble() >= m_repeat_action_probability) { + m_player_b_r = player_b_r; + m_player_b_theta = player_b_theta; + m_player_b_fire = player_b_fire; + } + + // If so desired, request one frame's worth of sound (this does nothing if recording + // is not enabled) + m_osystem->sound().recordNextFrame(); + + // Render screen if we're displaying it + m_osystem->screen().render(); + + // Similarly record screen as needed + if (m_screen_exporter.get() != NULL) + m_screen_exporter->saveNext(m_screen); + + // Use the stored actions, which may or may not have changed this frame + sum_rewards += oneStepActContinuous(m_player_a_r, m_player_a_theta, m_player_a_fire, + m_player_b_r, m_player_b_theta, m_player_b_fire, + continuous_action_threshold); + } + + return sum_rewards; +} + /** This functions emulates a push on the reset button of the console */ void StellaEnvironment::softReset() { emulate(RESET, PLAYER_B_NOOP, m_num_reset_steps); @@ -217,6 +262,31 @@ reward_t StellaEnvironment::oneStepAct(Action player_a_action, return m_settings->getReward(); } +/** Applies the given continuous actions (e.g. updating paddle positions when + * the paddle is used) and performs one simulation step in Stella. */ +reward_t StellaEnvironment::oneStepActContinuous( + float player_a_r, float player_a_theta, float player_a_fire, + float player_b_r, float player_b_theta, float player_b_fire, + float continuous_action_threshold) { + // Once in a terminal state, refuse to go any further (special actions must be handled + // outside of this environment; in particular reset() should be called rather than passing + // RESET or SYSTEM_RESET. 
+  if (isTerminal())
+    return 0;
+
+  // Convert illegal actions into NOOPs; actions such as reset are always legal
+  //noopIllegalActions(player_a_action, player_b_action);
+
+  // Emulate in the emulator
+  emulateContinuous(player_a_r, player_a_theta, player_a_fire,
+                    player_b_r, player_b_theta, player_b_fire,
+                    continuous_action_threshold);
+  // Increment the number of frames seen so far
+  m_state.incrementFrame();
+
+  return m_settings->getReward();
+}
+
 bool StellaEnvironment::isTerminal() const {
   return isGameTerminal() || isGameTruncated();
 }
@@ -287,6 +357,44 @@ void StellaEnvironment::emulate(Action player_a_action, Action player_b_action,
   processRAM();
 }
 
+void StellaEnvironment::emulateContinuous(
+    float player_a_r, float player_a_theta, float player_a_fire,
+    float player_b_r, float player_b_theta, float player_b_fire,
+    float continuous_action_threshold, size_t num_steps) {
+  Event* event = m_osystem->event();
+
+  // Handle paddles separately: we have to manually update the paddle positions at each step
+  if (m_use_paddles) {
+    // Run emulator forward for 'num_steps'
+    for (size_t t = 0; t < num_steps; t++) {
+      // Update paddle position at every step
+      m_state.applyActionPaddlesContinuous(
+          event,
+          player_a_r, player_a_theta, player_a_fire,
+          player_b_r, player_b_theta, player_b_fire,
+          continuous_action_threshold);
+
+      m_osystem->console().mediaSource().update();
+      m_settings->step(m_osystem->console().system());
+    }
+  } else {
+    // In joystick mode we only need to set the action events once
+    m_state.setActionJoysticksContinuous(
+        event, player_a_r, player_a_theta, player_a_fire,
+        player_b_r, player_b_theta, player_b_fire,
+        continuous_action_threshold);
+
+    for (size_t t = 0; t < num_steps; t++) {
+      m_osystem->console().mediaSource().update();
+      m_settings->step(m_osystem->console().system());
+    }
+  }
+
+  // Parse screen and RAM into their respective data structures
+  processScreen();
+  processRAM();
+}
+
 /** Accessor methods for the environment state. */
 void StellaEnvironment::setState(const ALEState& state) { m_state = state; }
 
diff --git a/src/environment/stella_environment.hpp b/src/environment/stella_environment.hpp
index 803f1e02b..7fda5ef79 100644
--- a/src/environment/stella_environment.hpp
+++ b/src/environment/stella_environment.hpp
@@ -60,6 +60,18 @@ class StellaEnvironment {
    */
   reward_t act(Action player_a_action, Action player_b_action);
 
+  /** Applies the given continuous actions (e.g. updating paddle positions when
+   * the paddle is used) and performs one simulation step in Stella. Returns the
+   * resultant reward. When frame skip is set to > 1, up to the corresponding
+   * number of simulation steps are performed. Note that the post-act() frame
+   * number might not correspond to the pre-act() frame number plus the frame
+   * skip.
+   */
+  reward_t actContinuous(
+      float player_a_r, float player_a_theta, float player_a_fire,
+      float player_b_r, float player_b_theta, float player_b_fire,
+      float continuous_action_threshold);
+
   /** This functions emulates a push on the reset button of the console */
   void softReset();
@@ -121,9 +133,22 @@
   /** This applies an action exactly one time step. Helper function to act(). */
   reward_t oneStepAct(Action player_a_action, Action player_b_action);
 
+  /** This applies a continuous action exactly one time step.
+   * Helper function to actContinuous().
+ */ + reward_t oneStepActContinuous( + float player_a_r, float player_a_theta, float player_a_fire, + float player_b_r, float player_b_theta, float player_b_fire, + float continuous_action_threshold = 0.5); + + /** Actually emulates the emulator for a given number of steps. */ void emulate(Action player_a_action, Action player_b_action, size_t num_steps = 1); + void emulateContinuous( + float player_a_r, float player_a_theta, float player_a_fire, + float player_b_r, float player_b_theta, float player_b_fire, + float continuous_action_threshold = 0.5, size_t num_steps = 1); /** Drops illegal actions, such as the fire button in skiing. Note that this is different * from the minimal set of actions. */ @@ -161,6 +186,9 @@ class StellaEnvironment { // The last actions taken by our players Action m_player_a_action, m_player_b_action; + float m_player_a_r, m_player_b_r; + float m_player_a_theta, m_player_b_theta; + float m_player_a_fire, m_player_b_fire; }; } // namespace ale diff --git a/src/python/__init__.pyi b/src/python/__init__.pyi index d1c619727..df671d9e1 100644 --- a/src/python/__init__.pyi +++ b/src/python/__init__.pyi @@ -103,6 +103,8 @@ class ALEInterface: def __init__(self) -> None: ... @overload def act(self, action: Action) -> int: ... + def actContinuous(self, r: float, theta: float, fire: float, + continuous_action_threshold: float) -> int: ... @overload def act(self, action: int) -> int: ... def cloneState(self, *, include_rng: bool = False) -> ALEState: ... diff --git a/src/python/ale_python_interface.hpp b/src/python/ale_python_interface.hpp index 7a6fb057e..6377f8db7 100644 --- a/src/python/ale_python_interface.hpp +++ b/src/python/ale_python_interface.hpp @@ -146,6 +146,7 @@ PYBIND11_MODULE(_ale_py, m) { ale::ALEPythonInterface::act) .def("act", (ale::reward_t(ale::ALEInterface::*)(ale::Action)) & ale::ALEInterface::act) + .def("actContinuous", &ale::ALEPythonInterface::actContinuous) .def("game_over", &ale::ALEPythonInterface::game_over, py::kw_only(), py::arg("with_truncation") = py::bool_(true)) .def("game_truncated", &ale::ALEPythonInterface::game_truncated) .def("reset_game", &ale::ALEPythonInterface::reset_game) diff --git a/src/python/env.py b/src/python/env.py index abb918f23..439b2eced 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -41,6 +41,8 @@ def __init__( frameskip: tuple[int, int] | int = 4, repeat_action_probability: float = 0.25, full_action_space: bool = False, + continuous_actions: bool = False, + continuous_action_threshold: float = 0.5, max_num_frames_per_episode: int | None = None, render_mode: Literal["human", "rgb_array"] | None = None, ): @@ -58,6 +60,8 @@ def __init__( repeat_action_probability: int => Probability to repeat actions, see Machado et al., 2018 full_action_space: bool => Use full action space? + continuous_actions: bool => Use continuous actions? + continuous_action_threshold: float => threshold used for continuous actions. max_num_frames_per_episode: int => Max number of frame per epsiode. Once `max_num_frames_per_episode` is reached the episode is truncated. 
@@ -117,6 +121,8 @@ def __init__( frameskip=frameskip, repeat_action_probability=repeat_action_probability, full_action_space=full_action_space, + continuous_actions=continuous_actions, + continuous_action_threshold=continuous_action_threshold, max_num_frames_per_episode=max_num_frames_per_episode, render_mode=render_mode, ) @@ -149,13 +155,24 @@ def __init__( self.seed_game() self.load_game() - # initialize action space - self._action_set = ( - self.ale.getLegalActionSet() - if full_action_space - else self.ale.getMinimalActionSet() - ) - self.action_space = spaces.Discrete(len(self._action_set)) + self.continuous_actions = continuous_actions + self.continuous_action_threshold = continuous_action_threshold + if continuous_actions: + # We don't need action_set for continuous actions. + self._action_set = None + # Actions are radius, theta, and fire, where first two are the + # parameters of polar coordinates. + self._action_space = spaces.Box( + np.array([0, -1, 0]).astype(np.float32), + np.array([+1, +1, +1]).astype(np.float32), + ) # radius, theta, fire. First two are polar coordinates. + else: + self._action_set = ( + self.ale.getLegalActionSet() + if full_action_space + else self.ale.getMinimalActionSet() + ) + self._action_space = spaces.Discrete(len(self._action_set)) # initialize observation space if self._obs_type == "ram": @@ -253,6 +270,45 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] return self._get_obs(), reward, is_terminal, is_truncated, self._get_info() + def continuousStep( + self, + r: float, + theta: float, + fire: float, + ) -> Tuple[np.ndarray, float, bool, bool, AtariEnvStepMetadata]: + """Perform one agent step, i.e., repeats `action` frameskip # of steps. + + Args: + r: float => radius of polar coordinate action + theta: float => angle of polar coordinate action + fire: float => continuous fire action + + Returns: + Tuple[np.ndarray, float, bool, Dict[str, Any]] => + observation, reward, terminal, metadata + + Note: `metadata` contains the keys "lives" and "rgb" if + render_mode == 'rgb_array'. + """ + # If frameskip is a length 2 tuple then it's stochastic + # frameskip between [frameskip[0], frameskip[1]] uniformly. + if isinstance(self._frameskip, int): + frameskip = self._frameskip + elif isinstance(self._frameskip, tuple): + frameskip = self.np_random.integers(*self._frameskip) + else: + raise error.Error(f"Invalid frameskip type: {self._frameskip}") + + # Frameskip + reward = 0.0 + for _ in range(frameskip): + reward += self.ale.actContinuous(r, theta, fire, + self.continuous_action_threshold) + is_terminal = self.ale.game_over(with_truncation=False) + is_truncated = self.ale.game_truncated() + + return self._get_obs(), reward, is_terminal, is_truncated, self._get_info() + def render(self) -> np.ndarray | None: """ Render is not supported by ALE. We use a paradigm similar to From 4b442863ff188045d9b693aeba565295d1173d0d Mon Sep 17 00:00:00 2001 From: psc-g Date: Fri, 7 Jun 2024 09:50:00 -0400 Subject: [PATCH 02/15] Set continuous action threshold once in ALE constructor, not with every call to actContinuous. 
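
The threshold now lives in ALE's settings rather than in every call. A minimal
sketch of the resulting direct-ALE usage, assuming the standard ale_py API
(the ROM path below is hypothetical):

    import ale_py

    ale = ale_py.ALEInterface()
    # Configure the threshold once, before loading the ROM...
    ale.setFloat("continuous_action_threshold", 0.5)
    ale.loadROM("/path/to/breakout.bin")  # hypothetical path
    # ...then omit it from every call; only (r, theta, fire) remain.
    reward = ale.actContinuous(0.9, 3.14 / 2.0, 0.0)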
--- src/ale_interface.cpp | 6 ++---- src/ale_interface.hpp | 3 +-- src/environment/stella_environment.cpp | 20 +++++++++----------- src/environment/stella_environment.hpp | 9 ++++----- src/python/__init__.pyi | 3 +-- src/python/env.py | 5 +++-- 6 files changed, 20 insertions(+), 26 deletions(-) diff --git a/src/ale_interface.cpp b/src/ale_interface.cpp index c6a79b3f5..eae846fdd 100644 --- a/src/ale_interface.cpp +++ b/src/ale_interface.cpp @@ -267,10 +267,8 @@ reward_t ALEInterface::act(Action action) { // user's responsibility to check if the game has ended and reset // when necessary - this method will keep pressing buttons on the // game over screen. -reward_t ALEInterface::actContinuous(float r, float theta, float fire, - float continuous_action_threshold) { - return environment->actContinuous(r, theta, fire, 0.0, 0.0, 0.0, - continuous_action_threshold); +reward_t ALEInterface::actContinuous(float r, float theta, float fire) { + return environment->actContinuous(r, theta, fire, 0.0, 0.0, 0.0); } // Returns the vector of modes available for the current game. diff --git a/src/ale_interface.hpp b/src/ale_interface.hpp index 68ae9b164..126f4e394 100644 --- a/src/ale_interface.hpp +++ b/src/ale_interface.hpp @@ -92,8 +92,7 @@ class ALEInterface { // user's responsibility to check if the game has ended and reset // when necessary - this method will keep pressing buttons on the // game over screen. - reward_t actContinuous(float r, float theta, float fire, - float continuous_action_threshold = 0.5); + reward_t actContinuous(float r, float theta, float fire); // Indicates if the game has ended. bool game_over(bool with_truncation = true) const; diff --git a/src/environment/stella_environment.cpp b/src/environment/stella_environment.cpp index 6ee48b124..5e57c8bf9 100644 --- a/src/environment/stella_environment.cpp +++ b/src/environment/stella_environment.cpp @@ -76,6 +76,8 @@ StellaEnvironment::StellaEnvironment(OSystem* osystem, RomSettings* settings) m_repeat_action_probability = m_osystem->settings().getFloat("repeat_action_probability"); + m_continuous_action_threshold = + m_osystem->settings().getFloat("continuous_action_threshold"); m_frame_skip = m_osystem->settings().getInt("frame_skip"); if (m_frame_skip < 1) { @@ -189,8 +191,7 @@ reward_t StellaEnvironment::act(Action player_a_action, reward_t StellaEnvironment::actContinuous( float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, - float continuous_action_threshold) { + float player_b_r, float player_b_theta, float player_b_fire) { // Total reward received as we repeat the action reward_t sum_rewards = 0; @@ -225,8 +226,7 @@ reward_t StellaEnvironment::actContinuous( // Use the stored actions, which may or may not have changed this frame sum_rewards += oneStepActContinuous(m_player_a_r, m_player_a_theta, m_player_a_fire, - m_player_b_r, m_player_b_theta, m_player_b_fire, - continuous_action_threshold); + m_player_b_r, m_player_b_theta, m_player_b_fire); } return sum_rewards; @@ -266,8 +266,7 @@ reward_t StellaEnvironment::oneStepAct(Action player_a_action, * the paddle is used) and performs one simulation step in Stella. 
*/ reward_t StellaEnvironment::oneStepActContinuous( float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, - float continuous_action_threshold) { + float player_b_r, float player_b_theta, float player_b_fire) { // Once in a terminal state, refuse to go any further (special actions must be handled // outside of this environment; in particular reset() should be called rather than passing // RESET or SYSTEM_RESET. @@ -279,8 +278,7 @@ reward_t StellaEnvironment::oneStepActContinuous( // Emulate in the emulator emulateContinuous(player_a_r, player_a_theta, player_a_fire, - player_b_r, player_b_theta, player_b_fire, - continuous_action_threshold); + player_b_r, player_b_theta, player_b_fire); // Increment the number of frames seen so far m_state.incrementFrame(); @@ -360,7 +358,7 @@ void StellaEnvironment::emulate(Action player_a_action, Action player_b_action, void StellaEnvironment::emulateContinuous( float player_a_r, float player_a_theta, float player_a_fire, float player_b_r, float player_b_theta, float player_b_fire, - float continuous_action_threshold, size_t num_steps) { + size_t num_steps) { Event* event = m_osystem->event(); // Handle paddles separately: we have to manually update the paddle positions at each step @@ -372,7 +370,7 @@ void StellaEnvironment::emulateContinuous( event, player_a_r, player_a_theta, player_a_fire, player_b_r, player_b_theta, player_b_fire, - continuous_action_threshold); + m_continuous_action_threshold); m_osystem->console().mediaSource().update(); m_settings->step(m_osystem->console().system()); @@ -382,7 +380,7 @@ void StellaEnvironment::emulateContinuous( m_state.setActionJoysticksContinuous( event, player_a_r, player_a_theta, player_a_fire, player_b_r, player_b_theta, player_b_fire, - continuous_action_threshold); + m_continuous_action_threshold); for (size_t t = 0; t < num_steps; t++) { m_osystem->console().mediaSource().update(); diff --git a/src/environment/stella_environment.hpp b/src/environment/stella_environment.hpp index 7fda5ef79..405fd2ca1 100644 --- a/src/environment/stella_environment.hpp +++ b/src/environment/stella_environment.hpp @@ -69,8 +69,7 @@ class StellaEnvironment { */ reward_t actContinuous( float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, - float continuous_action_threshold); + float player_b_r, float player_b_theta, float player_b_fire); /** This functions emulates a push on the reset button of the console */ void softReset(); @@ -138,8 +137,7 @@ class StellaEnvironment { */ reward_t oneStepActContinuous( float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, - float continuous_action_threshold = 0.5); + float player_b_r, float player_b_theta, float player_b_fire); /** Actually emulates the emulator for a given number of steps. */ @@ -148,7 +146,7 @@ class StellaEnvironment { void emulateContinuous( float player_a_r, float player_a_theta, float player_a_fire, float player_b_r, float player_b_theta, float player_b_fire, - float continuous_action_threshold = 0.5, size_t num_steps = 1); + size_t num_steps = 1); /** Drops illegal actions, such as the fire button in skiing. Note that this is different * from the minimal set of actions. 
 */
@@ -178,6 +176,7 @@ class StellaEnvironment {
   int m_max_num_frames_per_episode; // Maxmimum number of frames per episode
   size_t m_frame_skip; // How many frames to emulate per act()
   float m_repeat_action_probability; // Stochasticity of the environment
+  float m_continuous_action_threshold; // Continuous action threshold
   std::unique_ptr<ScreenExporter> m_screen_exporter; // Automatic screen recorder
   int m_max_lives; // Maximum number of lives at the start of an episode.
   bool m_truncate_on_loss_of_life; // Whether to truncate episodes on loss of life.
diff --git a/src/python/__init__.pyi b/src/python/__init__.pyi
index df671d9e1..cbc3f7881 100644
--- a/src/python/__init__.pyi
+++ b/src/python/__init__.pyi
@@ -103,8 +103,7 @@ class ALEInterface:
     def __init__(self) -> None: ...
     @overload
     def act(self, action: Action) -> int: ...
-    def actContinuous(self, r: float, theta: float, fire: float,
-                      continuous_action_threshold: float) -> int: ...
+    def actContinuous(self, r: float, theta: float, fire: float) -> int: ...
     @overload
     def act(self, action: int) -> int: ...
    def cloneState(self, *, include_rng: bool = False) -> ALEState: ...
diff --git a/src/python/env.py b/src/python/env.py
index 439b2eced..3d5b7860e 100644
--- a/src/python/env.py
+++ b/src/python/env.py
@@ -142,6 +142,8 @@ def __init__(
         self.ale.setLoggerMode(ale_py.LoggerMode.Error)
         # Config sticky action prob.
         self.ale.setFloat("repeat_action_probability", repeat_action_probability)
+        # Config continuous action threshold (if using continuous actions).
+        self.ale.setFloat("continuous_action_threshold", continuous_action_threshold)
 
         if max_num_frames_per_episode is not None:
             self.ale.setInt("max_num_frames_per_episode", max_num_frames_per_episode)
@@ -302,8 +304,7 @@ def continuousStep(
         # Frameskip
         reward = 0.0
         for _ in range(frameskip):
-            reward += self.ale.actContinuous(r, theta, fire,
-                                             self.continuous_action_threshold)
+            reward += self.ale.actContinuous(r, theta, fire)
         is_terminal = self.ale.game_over(with_truncation=False)
         is_truncated = self.ale.game_truncated()
 

From 87af1ec121bfc1e734462803352bb5758ee1b972 Mon Sep 17 00:00:00 2001
From: psc-g
Date: Fri, 7 Jun 2024 10:08:01 -0400
Subject: [PATCH 03/15] Set default value for continuous action threshold.

---
 src/emucore/Settings.cxx | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/emucore/Settings.cxx b/src/emucore/Settings.cxx
index 8e9193aae..bb11a8ca0 100644
--- a/src/emucore/Settings.cxx
+++ b/src/emucore/Settings.cxx
@@ -415,6 +415,7 @@ void Settings::setDefaultSettings() {
   boolSettings.insert(std::pair<std::string, bool>("send_rgb", false));
   intSettings.insert(std::pair<std::string, int>("frame_skip", 1));
   floatSettings.insert(std::pair<std::string, float>("repeat_action_probability", 0.25));
+  floatSettings.insert(std::pair<std::string, float>("continuous_action_threshold", 0.5));
   stringSettings.insert(std::pair<std::string, std::string>("rom_file", ""));
   // Whether to truncate an episode on loss of life.
   boolSettings.insert(std::pair<std::string, bool>("truncate_on_loss_of_life", false));

From 26ef98fb04d7a2cbd83787dc1fc7d06856094c4d Mon Sep 17 00:00:00 2001
From: psc-g
Date: Mon, 17 Jun 2024 17:15:08 -0400
Subject: [PATCH 04/15] Use unified function for discrete and continuous actions.
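
With the unified entry point, a single Gymnasium step() call serves both
action modes. A sketch, assuming the registered ALE/Breakout-v5 id forwards
these keyword arguments to the environment:

    import gymnasium as gym

    # Discrete (default): step takes an action index.
    env = gym.make("ALE/Breakout-v5")
    env.reset(seed=0)
    env.step(0)

    # Continuous: the same step takes (r, theta, fire).
    env = gym.make("ALE/Breakout-v5", continuous=True)
    env.reset(seed=0)
    env.step((0.9, -0.5, 0.0))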
--- src/python/env.py | 59 +++++++++++------------------------------------ 1 file changed, 13 insertions(+), 46 deletions(-) diff --git a/src/python/env.py b/src/python/env.py index 3d5b7860e..36c871b53 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -41,7 +41,7 @@ def __init__( frameskip: tuple[int, int] | int = 4, repeat_action_probability: float = 0.25, full_action_space: bool = False, - continuous_actions: bool = False, + continuous: bool = False, continuous_action_threshold: float = 0.5, max_num_frames_per_episode: int | None = None, render_mode: Literal["human", "rgb_array"] | None = None, @@ -60,7 +60,7 @@ def __init__( repeat_action_probability: int => Probability to repeat actions, see Machado et al., 2018 full_action_space: bool => Use full action space? - continuous_actions: bool => Use continuous actions? + continuous: bool => Use continuous actions? continuous_action_threshold: float => threshold used for continuous actions. max_num_frames_per_episode: int => Max number of frame per epsiode. Once `max_num_frames_per_episode` is reached the episode is @@ -121,7 +121,7 @@ def __init__( frameskip=frameskip, repeat_action_probability=repeat_action_probability, full_action_space=full_action_space, - continuous_actions=continuous_actions, + continuous=continuous, continuous_action_threshold=continuous_action_threshold, max_num_frames_per_episode=max_num_frames_per_episode, render_mode=render_mode, @@ -157,9 +157,9 @@ def __init__( self.seed_game() self.load_game() - self.continuous_actions = continuous_actions + self.continuous = continuous self.continuous_action_threshold = continuous_action_threshold - if continuous_actions: + if continuous: # We don't need action_set for continuous actions. self._action_set = None # Actions are radius, theta, and fire, where first two are the @@ -239,13 +239,14 @@ def reset( # pyright: ignore[reportIncompatibleMethodOverride] def step( # pyright: ignore[reportIncompatibleMethodOverride] self, - action: int, + action: int | tuple, ) -> tuple[np.ndarray, float, bool, bool, AtariEnvStepMetadata]: """ Perform one agent step, i.e., repeats `action` frameskip # of steps. Args: - action_ind: int => Action index to execute + action_ind: int | tuple => Action index to execute, or tuple of floats + if continuous. Returns: tuple[np.ndarray, float, bool, bool, Dict[str, Any]] => @@ -266,50 +267,16 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] # Frameskip reward = 0.0 for _ in range(frameskip): - reward += self.ale.act(self._action_set[action]) + if self.continuous: + r, theta, fire = action + reward += self.ale.actContinuous(r, theta, fire) + else: + reward += self.ale.act(self._action_set[action]) is_terminal = self.ale.game_over(with_truncation=False) is_truncated = self.ale.game_truncated() return self._get_obs(), reward, is_terminal, is_truncated, self._get_info() - def continuousStep( - self, - r: float, - theta: float, - fire: float, - ) -> Tuple[np.ndarray, float, bool, bool, AtariEnvStepMetadata]: - """Perform one agent step, i.e., repeats `action` frameskip # of steps. - - Args: - r: float => radius of polar coordinate action - theta: float => angle of polar coordinate action - fire: float => continuous fire action - - Returns: - Tuple[np.ndarray, float, bool, Dict[str, Any]] => - observation, reward, terminal, metadata - - Note: `metadata` contains the keys "lives" and "rgb" if - render_mode == 'rgb_array'. 
- """ - # If frameskip is a length 2 tuple then it's stochastic - # frameskip between [frameskip[0], frameskip[1]] uniformly. - if isinstance(self._frameskip, int): - frameskip = self._frameskip - elif isinstance(self._frameskip, tuple): - frameskip = self.np_random.integers(*self._frameskip) - else: - raise error.Error(f"Invalid frameskip type: {self._frameskip}") - - # Frameskip - reward = 0.0 - for _ in range(frameskip): - reward += self.ale.actContinuous(r, theta, fire) - is_terminal = self.ale.game_over(with_truncation=False) - is_truncated = self.ale.game_truncated() - - return self._get_obs(), reward, is_terminal, is_truncated, self._get_info() - def render(self) -> np.ndarray | None: """ Render is not supported by ALE. We use a paradigm similar to From 3c437353803486734fc4371d2d6cb41a7a05c723 Mon Sep 17 00:00:00 2001 From: psc-g Date: Tue, 2 Jul 2024 15:40:50 -0400 Subject: [PATCH 05/15] Change step argument option from tuple to np.ndarray. --- src/python/env.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/python/env.py b/src/python/env.py index 36c871b53..d07d8046b 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -239,14 +239,14 @@ def reset( # pyright: ignore[reportIncompatibleMethodOverride] def step( # pyright: ignore[reportIncompatibleMethodOverride] self, - action: int | tuple, + action: int | np.ndarray, ) -> tuple[np.ndarray, float, bool, bool, AtariEnvStepMetadata]: """ Perform one agent step, i.e., repeats `action` frameskip # of steps. Args: - action_ind: int | tuple => Action index to execute, or tuple of floats - if continuous. + action_ind: int | np.ndarray => Action index to execute, or numpy + array of floats if continuous. Returns: tuple[np.ndarray, float, bool, bool, Dict[str, Any]] => From bf6610de5d78dd06d05639da2f2800f07da5fb6a Mon Sep 17 00:00:00 2001 From: psc-g Date: Fri, 5 Jul 2024 13:26:14 -0400 Subject: [PATCH 06/15] Force action to tuple before splitting into three components. --- src/python/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/env.py b/src/python/env.py index d07d8046b..479908998 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -268,7 +268,7 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] reward = 0.0 for _ in range(frameskip): if self.continuous: - r, theta, fire = action + r, theta, fire = tuple(action) reward += self.ale.actContinuous(r, theta, fire) else: reward += self.ale.act(self._action_set[action]) From 0eaf8db78105dc17b75278e1f771f6eef9466209 Mon Sep 17 00:00:00 2001 From: psc-g Date: Fri, 5 Jul 2024 13:29:27 -0400 Subject: [PATCH 07/15] Add a check to ensure that 3-dimensional actions are being passed in. 
--- src/python/env.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/python/env.py b/src/python/env.py index 479908998..c8c42743e 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -268,7 +268,11 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] reward = 0.0 for _ in range(frameskip): if self.continuous: - r, theta, fire = tuple(action) + action = tuple(action) + if len(action) != 3: + raise ValueError('Actions must have 3-dimensions.') + + r, theta, fire = action reward += self.ale.actContinuous(r, theta, fire) else: reward += self.ale.act(self._action_set[action]) From 11a980c2f011d7dc27b2e2d4690ddb9089cc743c Mon Sep 17 00:00:00 2001 From: psc-g Date: Thu, 11 Jul 2024 16:03:48 -0400 Subject: [PATCH 08/15] Fixing style issues --- src/ale_interface.hpp | 2 +- src/python/env.py | 30 +++++++++++++++--------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/ale_interface.hpp b/src/ale_interface.hpp index 126f4e394..832adf592 100644 --- a/src/ale_interface.hpp +++ b/src/ale_interface.hpp @@ -87,7 +87,7 @@ class ALEInterface { // when necessary - this method will keep pressing buttons on the // game over screen. reward_t act(Action action); - + // Applies a continuous action to the game and returns the reward. It is the // user's responsibility to check if the game has ended and reset // when necessary - this method will keep pressing buttons on the diff --git a/src/python/env.py b/src/python/env.py index c8c42743e..a1d3cba6b 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -160,21 +160,21 @@ def __init__( self.continuous = continuous self.continuous_action_threshold = continuous_action_threshold if continuous: - # We don't need action_set for continuous actions. - self._action_set = None - # Actions are radius, theta, and fire, where first two are the - # parameters of polar coordinates. - self._action_space = spaces.Box( - np.array([0, -1, 0]).astype(np.float32), - np.array([+1, +1, +1]).astype(np.float32), - ) # radius, theta, fire. First two are polar coordinates. + # We don't need action_set for continuous actions. + self._action_set = None + # Actions are radius, theta, and fire, where first two are the + # parameters of polar coordinates. + self._action_space = spaces.Box( + np.array([0, -1, 0]).astype(np.float32), + np.array([+1, +1, +1]).astype(np.float32), + ) # radius, theta, fire. First two are polar coordinates. else: - self._action_set = ( - self.ale.getLegalActionSet() - if full_action_space - else self.ale.getMinimalActionSet() - ) - self._action_space = spaces.Discrete(len(self._action_set)) + self._action_set = ( + self.ale.getLegalActionSet() + if full_action_space + else self.ale.getMinimalActionSet() + ) + self._action_space = spaces.Discrete(len(self._action_set)) # initialize observation space if self._obs_type == "ram": @@ -270,7 +270,7 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] if self.continuous: action = tuple(action) if len(action) != 3: - raise ValueError('Actions must have 3-dimensions.') + raise ValueError("Actions must have 3-dimensions.") r, theta, fire = action reward += self.ale.actContinuous(r, theta, fire) From f857a4e56cbdfa740f111ddce80246ad047dbc49 Mon Sep 17 00:00:00 2001 From: psc-g Date: Mon, 15 Jul 2024 13:00:14 -0400 Subject: [PATCH 09/15] Update action_space variable name for gymnasium compatibility. 
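
Gymnasium and its wrappers resolve the public action_space attribute (for
example via env.action_space.sample()), so a private _action_space name is
not visible downstream. A sketch of what the rename enables:

    env = gym.make("ALE/Breakout-v5", continuous=True)
    assert isinstance(env.action_space, gym.spaces.Box)
    env.step(env.action_space.sample())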
--- src/python/env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/python/env.py b/src/python/env.py index a1d3cba6b..cd3ae580a 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -164,7 +164,7 @@ def __init__( self._action_set = None # Actions are radius, theta, and fire, where first two are the # parameters of polar coordinates. - self._action_space = spaces.Box( + self.action_space = spaces.Box( np.array([0, -1, 0]).astype(np.float32), np.array([+1, +1, +1]).astype(np.float32), ) # radius, theta, fire. First two are polar coordinates. @@ -174,7 +174,7 @@ def __init__( if full_action_space else self.ale.getMinimalActionSet() ) - self._action_space = spaces.Discrete(len(self._action_set)) + self.action_space = spaces.Discrete(len(self._action_set)) # initialize observation space if self._obs_type == "ram": From f9c85602221e6f6a3a02820f42c115a6a0860006 Mon Sep 17 00:00:00 2001 From: psc-g Date: Mon, 15 Jul 2024 13:30:51 -0400 Subject: [PATCH 10/15] Add unit tests for continuous action spaces. --- src/python/env.py | 2 +- tests/python/test_atari_env.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/python/env.py b/src/python/env.py index cd3ae580a..37f1c161e 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -270,7 +270,7 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] if self.continuous: action = tuple(action) if len(action) != 3: - raise ValueError("Actions must have 3-dimensions.") + raise error.Error("Actions must have 3-dimensions.") r, theta, fire = action reward += self.ale.actContinuous(r, theta, fire) diff --git a/tests/python/test_atari_env.py b/tests/python/test_atari_env.py index 1521cc01b..e50388bb4 100644 --- a/tests/python/test_atari_env.py +++ b/tests/python/test_atari_env.py @@ -329,6 +329,35 @@ def test_gym_action_space(tetris_env): assert tetris_env.action_space.n == 18 +@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) +def test_continuous_action_space(tetris_env): + assert isinstance(tetris_env.action_space, gymnasium.spaces.Box) + assert len(tetris_env.action_space.shape) == 1 + assert tetris_env.action_space.shape[0] == 3 + np.testing.assert_array_equal(tetris_env.action_space.low, np.array([0., -1., 0.])) + np.testing.assert_array_equal(tetris_env.action_space.high, np.array([1., 1., 1.])) + + +@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) +def test_continuous_action_sample(tetris_env): + tetris_env.reset(seed=0) + for _ in range(100): + tetris_env.step(tetris_env.action_space.sample()) + + +@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) +def test_continuous_step_with_correct_dimensions(tetris_env): + tetris_env.reset(seed=0) + tetris_env.step([0., -0.5, 0.5]) + + +@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) +def test_continuous_step_fails_with_wrong_dimensions(tetris_env): + tetris_env.reset(seed=0) + with pytest.raises(gymnasium.error.Error): + tetris_env.step([0., 1.]) + + def test_gym_reset_with_infos(tetris_env): pack = tetris_env.reset(seed=0) From 5586b9b7f8d0612427f5f2c8cac5ae85b343e543 Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Tue, 23 Jul 2024 12:15:46 +0100 Subject: [PATCH 11/15] pre-commit --- tests/python/test_atari_env.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/python/test_atari_env.py b/tests/python/test_atari_env.py index e50388bb4..4ad1efcf7 100644 --- 
a/tests/python/test_atari_env.py +++ b/tests/python/test_atari_env.py @@ -334,8 +334,12 @@ def test_continuous_action_space(tetris_env): assert isinstance(tetris_env.action_space, gymnasium.spaces.Box) assert len(tetris_env.action_space.shape) == 1 assert tetris_env.action_space.shape[0] == 3 - np.testing.assert_array_equal(tetris_env.action_space.low, np.array([0., -1., 0.])) - np.testing.assert_array_equal(tetris_env.action_space.high, np.array([1., 1., 1.])) + np.testing.assert_array_equal( + tetris_env.action_space.low, np.array([0.0, -1.0, 0.0]) + ) + np.testing.assert_array_equal( + tetris_env.action_space.high, np.array([1.0, 1.0, 1.0]) + ) @pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) @@ -348,14 +352,14 @@ def test_continuous_action_sample(tetris_env): @pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) def test_continuous_step_with_correct_dimensions(tetris_env): tetris_env.reset(seed=0) - tetris_env.step([0., -0.5, 0.5]) + tetris_env.step([0.0, -0.5, 0.5]) @pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) def test_continuous_step_fails_with_wrong_dimensions(tetris_env): tetris_env.reset(seed=0) with pytest.raises(gymnasium.error.Error): - tetris_env.step([0., 1.]) + tetris_env.step([0.0, 1.0]) def test_gym_reset_with_infos(tetris_env): From 76779d77ceba17d9210875c25fa3b34719f9a876 Mon Sep 17 00:00:00 2001 From: psc-g Date: Thu, 25 Jul 2024 17:26:44 +0200 Subject: [PATCH 12/15] Fix action limits --- src/python/env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/python/env.py b/src/python/env.py index 37f1c161e..b1a30fd6e 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -165,8 +165,8 @@ def __init__( # Actions are radius, theta, and fire, where first two are the # parameters of polar coordinates. self.action_space = spaces.Box( - np.array([0, -1, 0]).astype(np.float32), - np.array([+1, +1, +1]).astype(np.float32), + np.array([0, -np.pi, 0]).astype(np.float32), + np.array([+1, +np.pi, +1]).astype(np.float32), ) # radius, theta, fire. First two are polar coordinates. else: self._action_set = ( From 26f777aadc022ee76744846a6d1a87449049efd4 Mon Sep 17 00:00:00 2001 From: psc-g Date: Thu, 25 Jul 2024 17:32:15 +0200 Subject: [PATCH 13/15] Use action.tolist() instead of converting to tuple --- src/python/env.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/python/env.py b/src/python/env.py index b1a30fd6e..113305b2b 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -268,11 +268,10 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] reward = 0.0 for _ in range(frameskip): if self.continuous: - action = tuple(action) if len(action) != 3: raise error.Error("Actions must have 3-dimensions.") - r, theta, fire = action + r, theta, fire = action.tolist() reward += self.ale.actContinuous(r, theta, fire) else: reward += self.ale.act(self._action_set[action]) From 5e42c6aec1587e7ce26926859323320e04d89d2a Mon Sep 17 00:00:00 2001 From: psc-g Date: Thu, 25 Jul 2024 17:48:09 +0200 Subject: [PATCH 14/15] Fix threshold checks for paddles. 
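
With threshold t, the old paddle check tested x > t for a right action but
x < t for a left action, so the dead zone was lost: any x-component below t,
including values near zero, registered as a left action. A small numeric
illustration of the corrected rule:

    t = 0.5
    a_x = 0.2            # inside the intended dead zone
    old_left = a_x < t   # True:  paddle drifts left every frame
    new_left = a_x < -t  # False: paddle stays put, as intended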
--- src/environment/ale_state.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/environment/ale_state.cpp b/src/environment/ale_state.cpp index fd8749be5..9c7ae8393 100644 --- a/src/environment/ale_state.cpp +++ b/src/environment/ale_state.cpp @@ -307,13 +307,13 @@ void ALEState::applyActionPaddlesContinuous( int delta_a = 0; if (a_x > continuous_action_threshold) { // Right action. delta_a = -PADDLE_DELTA; - } else if (a_x < continuous_action_threshold) { // Left action. + } else if (a_x < -continuous_action_threshold) { // Left action. delta_a = PADDLE_DELTA; } int delta_b = 0; if (b_x > continuous_action_threshold) { // Right action. delta_b = -PADDLE_DELTA; - } else if (b_x < continuous_action_threshold) { // Left action. + } else if (b_x < -continuous_action_threshold) { // Left action. delta_b = PADDLE_DELTA; } From 6aadbb7afbe54f122baa53d0992f5af66f10139e Mon Sep 17 00:00:00 2001 From: psc-g Date: Sun, 28 Jul 2024 09:42:06 +0200 Subject: [PATCH 15/15] Fixing tests --- src/python/env.py | 4 +++- tests/python/test_atari_env.py | 10 ++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/python/env.py b/src/python/env.py index 113305b2b..129052fc8 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -271,7 +271,9 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] if len(action) != 3: raise error.Error("Actions must have 3-dimensions.") - r, theta, fire = action.tolist() + if isinstance(action, np.ndarray): + action = action.tolist() + r, theta, fire = action reward += self.ale.actContinuous(r, theta, fire) else: reward += self.ale.act(self._action_set[action]) diff --git a/tests/python/test_atari_env.py b/tests/python/test_atari_env.py index 4ad1efcf7..54e2d113e 100644 --- a/tests/python/test_atari_env.py +++ b/tests/python/test_atari_env.py @@ -334,11 +334,11 @@ def test_continuous_action_space(tetris_env): assert isinstance(tetris_env.action_space, gymnasium.spaces.Box) assert len(tetris_env.action_space.shape) == 1 assert tetris_env.action_space.shape[0] == 3 - np.testing.assert_array_equal( - tetris_env.action_space.low, np.array([0.0, -1.0, 0.0]) + np.testing.assert_array_almost_equal( + tetris_env.action_space.low, np.array([0.0, -np.pi, 0.0]) ) - np.testing.assert_array_equal( - tetris_env.action_space.high, np.array([1.0, 1.0, 1.0]) + np.testing.assert_array_almost_equal( + tetris_env.action_space.high, np.array([1.0, np.pi, 1.0]) ) @@ -352,7 +352,9 @@ def test_continuous_action_sample(tetris_env): @pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) def test_continuous_step_with_correct_dimensions(tetris_env): tetris_env.reset(seed=0) + # Test with both regular list and numpy array. tetris_env.step([0.0, -0.5, 0.5]) + tetris_env.step(np.array([0.0, -0.5, 0.5])) @pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True)
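
Taken together, the series supports the following end-to-end flow; a sketch,
with the environment id and keyword names taken from the diffs above:

    import gymnasium as gym
    import numpy as np

    env = gym.make(
        "ALE/Breakout-v5", continuous=True, continuous_action_threshold=0.5
    )
    obs, info = env.reset(seed=0)
    # Actions are (r, theta, fire): r in [0, 1], theta in [-pi, pi], fire in
    # [0, 1]. The polar pair becomes x = r*cos(theta), y = r*sin(theta) and is
    # thresholded into joystick or paddle events on the C++ side.
    obs, reward, terminated, truncated, info = env.step(
        np.array([1.0, np.pi / 2.0, 0.0], dtype=np.float32)  # push "up"
    )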