From b24830ff8b2d160708f94a732b40ba01e1c39560 Mon Sep 17 00:00:00 2001 From: Pablo Samuel Castro Date: Wed, 31 Jul 2024 08:56:54 -0400 Subject: [PATCH 1/3] Add support for continuous actions in the ALE (CALE) (#539) Co-authored-by: Mark Towers --- src/ale_interface.cpp | 8 ++ src/ale_interface.hpp | 6 ++ src/emucore/Settings.cxx | 1 + src/environment/ale_state.cpp | 90 +++++++++++++++++++++ src/environment/ale_state.hpp | 14 ++++ src/environment/stella_environment.cpp | 106 +++++++++++++++++++++++++ src/environment/stella_environment.hpp | 27 +++++++ src/python/__init__.pyi | 1 + src/python/ale_python_interface.hpp | 1 + src/python/env.py | 49 +++++++++--- tests/python/test_atari_env.py | 35 ++++++++ 11 files changed, 328 insertions(+), 10 deletions(-) diff --git a/src/ale_interface.cpp b/src/ale_interface.cpp index 83a146ea2..eae846fdd 100644 --- a/src/ale_interface.cpp +++ b/src/ale_interface.cpp @@ -263,6 +263,14 @@ reward_t ALEInterface::act(Action action) { return environment->act(action, PLAYER_B_NOOP); } +// Applies a continuous action to the game and returns the reward. It is the +// user's responsibility to check if the game has ended and reset +// when necessary - this method will keep pressing buttons on the +// game over screen. +reward_t ALEInterface::actContinuous(float r, float theta, float fire) { + return environment->actContinuous(r, theta, fire, 0.0, 0.0, 0.0); +} + // Returns the vector of modes available for the current game. // This should be called only after the rom is loaded. ModeVect ALEInterface::getAvailableModes() const { diff --git a/src/ale_interface.hpp b/src/ale_interface.hpp index 509f094d7..832adf592 100644 --- a/src/ale_interface.hpp +++ b/src/ale_interface.hpp @@ -88,6 +88,12 @@ class ALEInterface { // game over screen. reward_t act(Action action); + // Applies a continuous action to the game and returns the reward. 
It is the + // user's responsibility to check if the game has ended and reset + // when necessary - this method will keep pressing buttons on the + // game over screen. + reward_t actContinuous(float r, float theta, float fire); + // Indicates if the game has ended. bool game_over(bool with_truncation = true) const; diff --git a/src/emucore/Settings.cxx b/src/emucore/Settings.cxx index 8e9193aae..bb11a8ca0 100644 --- a/src/emucore/Settings.cxx +++ b/src/emucore/Settings.cxx @@ -415,6 +415,7 @@ void Settings::setDefaultSettings() { boolSettings.insert(std::pair("send_rgb", false)); intSettings.insert(std::pair("frame_skip", 1)); floatSettings.insert(std::pair("repeat_action_probability", 0.25)); + floatSettings.insert(std::pair("continuous_action_threshold", 0.5)); stringSettings.insert(std::pair("rom_file", "")); // Whether to truncate an episode on loss of life. boolSettings.insert(std::pair("truncate_on_loss_of_life", false)); diff --git a/src/environment/ale_state.cpp b/src/environment/ale_state.cpp index 8f4eafe82..9c7ae8393 100644 --- a/src/environment/ale_state.cpp +++ b/src/environment/ale_state.cpp @@ -13,6 +13,7 @@ #include "environment/ale_state.hpp" #include +#include #include #include #include @@ -287,6 +288,47 @@ void ALEState::applyActionPaddles(Event* event, int player_a_action, } } +void ALEState::applyActionPaddlesContinuous( + Event* event, + float player_a_r, float player_a_theta, float player_a_fire, + float player_b_r, float player_b_theta, float player_b_fire, + float continuous_action_threshold) { + // Reset keys + resetKeys(event); + + // Convert polar coordinates to x/y position. 
+ float a_x = player_a_r * cos(player_a_theta); + float a_y = player_a_r * sin(player_a_theta); + float b_x = player_b_r * cos(player_b_theta); + float b_y = player_b_r * sin(player_b_theta); + + // First compute whether we should increase or decrease the paddle position + // (for both left and right players) + int delta_a = 0; + if (a_x > continuous_action_threshold) { // Right action. + delta_a = -PADDLE_DELTA; + } else if (a_x < -continuous_action_threshold) { // Left action. + delta_a = PADDLE_DELTA; + } + int delta_b = 0; + if (b_x > continuous_action_threshold) { // Right action. + delta_b = -PADDLE_DELTA; + } else if (b_x < -continuous_action_threshold) { // Left action. + delta_b = PADDLE_DELTA; + } + + // Now update the paddle positions + updatePaddlePositions(event, delta_a, delta_b); + + // Now add the fire event + if (player_a_fire > continuous_action_threshold) { + event->set(Event::PaddleZeroFire, 1); + } + if (player_b_fire > continuous_action_threshold) { + event->set(Event::PaddleOneFire, 1); + } +} + void ALEState::pressSelect(Event* event) { resetKeys(event); event->set(Event::ConsoleSelect, 1); @@ -498,6 +540,54 @@ void ALEState::setActionJoysticks(Event* event, int player_a_action, } } + +void ALEState::setActionJoysticksContinuous( + Event* event, + float player_a_r, float player_a_theta, float player_a_fire, + float player_b_r, float player_b_theta, float player_b_fire, + float continuous_action_threshold) { + // Reset keys + resetKeys(event); + + // Convert polar coordinates to x/y position. + float a_x = player_a_r * cos(player_a_theta); + float a_y = player_a_r * sin(player_a_theta); + float b_x = player_b_r * cos(player_b_theta); + float b_y = player_b_r * sin(player_b_theta); + + // Go through all possible events and add them if joystick position is there. 
+ if (a_x < -continuous_action_threshold) { + event->set(Event::JoystickZeroLeft, 1); + } + if (a_x > continuous_action_threshold) { + event->set(Event::JoystickZeroRight, 1); + } + if (a_y < -continuous_action_threshold) { + event->set(Event::JoystickZeroDown, 1); + } + if (a_y > continuous_action_threshold) { + event->set(Event::JoystickZeroUp, 1); + } + if (player_a_fire > continuous_action_threshold) { + event->set(Event::JoystickZeroFire, 1); + } + if (b_x < -continuous_action_threshold) { + event->set(Event::JoystickOneLeft, 1); + } + if (b_x > continuous_action_threshold) { + event->set(Event::JoystickOneRight, 1); + } + if (b_y < -continuous_action_threshold) { + event->set(Event::JoystickOneDown, 1); + } + if (b_y > continuous_action_threshold) { + event->set(Event::JoystickOneUp, 1); + } + if (player_b_fire > continuous_action_threshold) { + event->set(Event::JoystickOneFire, 1); + } +} + /* *************************************************************************** Function resetKeys Unpresses all control-relevant keys diff --git a/src/environment/ale_state.hpp b/src/environment/ale_state.hpp index 0f1f5641f..6aad84683 100644 --- a/src/environment/ale_state.hpp +++ b/src/environment/ale_state.hpp @@ -64,9 +64,23 @@ class ALEState { * resistances. */ void applyActionPaddles(stella::Event* event_obj, int player_a_action, int player_b_action); + /** Applies paddle continuous actions. This actually modifies the game state + * by updating the paddle resistances. */ + void applyActionPaddlesContinuous( + stella::Event* event_obj, + float player_a_r, float player_a_theta, float player_a_fire, + float player_b_r, float player_b_theta, float player_b_fire, + float continuous_action_treshold = 0.5); /** Sets the joystick events. No effect until the emulator is run forward. */ void setActionJoysticks(stella::Event* event_obj, int player_a_action, int player_b_action); + /** Sets the joystick events for continuous actions. 
No effect until the + * emulator is run forward. */ + void setActionJoysticksContinuous( + stella::Event* event_obj, + float player_a_r, float player_a_theta, float player_a_fire, + float player_b_r, float player_b_theta, float player_b_fire, + float continuous_action_threshold = 0.5); void incrementFrame(int steps = 1); diff --git a/src/environment/stella_environment.cpp b/src/environment/stella_environment.cpp index a54443826..5e57c8bf9 100644 --- a/src/environment/stella_environment.cpp +++ b/src/environment/stella_environment.cpp @@ -76,6 +76,8 @@ StellaEnvironment::StellaEnvironment(OSystem* osystem, RomSettings* settings) m_repeat_action_probability = m_osystem->settings().getFloat("repeat_action_probability"); + m_continuous_action_threshold = + m_osystem->settings().getFloat("continuous_action_threshold"); m_frame_skip = m_osystem->settings().getInt("frame_skip"); if (m_frame_skip < 1) { @@ -187,6 +189,49 @@ reward_t StellaEnvironment::act(Action player_a_action, return std::clamp(sum_rewards, m_reward_min, m_reward_max); } +reward_t StellaEnvironment::actContinuous( + float player_a_r, float player_a_theta, float player_a_fire, + float player_b_r, float player_b_theta, float player_b_fire) { + // Total reward received as we repeat the action + reward_t sum_rewards = 0; + + Random& rng = getEnvironmentRNG(); + + // Apply the same action for a given number of times... note that act() will refuse to emulate + // past the terminal state + for (size_t i = 0; i < m_frame_skip; i++) { + // Stochastically drop actions, according to m_repeat_action_probability + if (rng.nextDouble() >= m_repeat_action_probability) { + m_player_a_r = player_a_r; + m_player_a_theta = player_a_theta; + m_player_a_fire = player_a_fire; + } + // @todo Possibly optimize by avoiding call to rand() when player B is "off" ? 
+ if (rng.nextDouble() >= m_repeat_action_probability) { + m_player_b_r = player_b_r; + m_player_b_theta = player_b_theta; + m_player_b_fire = player_b_fire; + } + + // If so desired, request one frame's worth of sound (this does nothing if recording + // is not enabled) + m_osystem->sound().recordNextFrame(); + + // Render screen if we're displaying it + m_osystem->screen().render(); + + // Similarly record screen as needed + if (m_screen_exporter.get() != NULL) + m_screen_exporter->saveNext(m_screen); + + // Use the stored actions, which may or may not have changed this frame + sum_rewards += oneStepActContinuous(m_player_a_r, m_player_a_theta, m_player_a_fire, + m_player_b_r, m_player_b_theta, m_player_b_fire); + } + + return sum_rewards; +} + /** This functions emulates a push on the reset button of the console */ void StellaEnvironment::softReset() { emulate(RESET, PLAYER_B_NOOP, m_num_reset_steps); @@ -217,6 +262,29 @@ reward_t StellaEnvironment::oneStepAct(Action player_a_action, return m_settings->getReward(); } +/** Applies the given continuous actions (e.g. updating paddle positions when + * the paddle is used) and performs one simulation step in Stella. */ +reward_t StellaEnvironment::oneStepActContinuous( + float player_a_r, float player_a_theta, float player_a_fire, + float player_b_r, float player_b_theta, float player_b_fire) { + // Once in a terminal state, refuse to go any further (special actions must be handled + // outside of this environment; in particular reset() should be called rather than passing + // RESET or SYSTEM_RESET. 
+ if (isTerminal()) + return 0; + + // Convert illegal actions into NOOPs; actions such as reset are always legal + //noopIllegalActions(player_a_action, player_b_action); + + // Emulate in the emulator + emulateContinuous(player_a_r, player_a_theta, player_a_fire, + player_b_r, player_b_theta, player_b_fire); + // Increment the number of frames seen so far + m_state.incrementFrame(); + + return m_settings->getReward(); +} + bool StellaEnvironment::isTerminal() const { return isGameTerminal() || isGameTruncated(); } @@ -287,6 +355,44 @@ void StellaEnvironment::emulate(Action player_a_action, Action player_b_action, processRAM(); } +void StellaEnvironment::emulateContinuous( + float player_a_r, float player_a_theta, float player_a_fire, + float player_b_r, float player_b_theta, float player_b_fire, + size_t num_steps) { + Event* event = m_osystem->event(); + + // Handle paddles separately: we have to manually update the paddle positions at each step + if (m_use_paddles) { + // Run emulator forward for 'num_steps' + for (size_t t = 0; t < num_steps; t++) { + // Update paddle position at every step + m_state.applyActionPaddlesContinuous( + event, + player_a_r, player_a_theta, player_a_fire, + player_b_r, player_b_theta, player_b_fire, + m_continuous_action_threshold); + + m_osystem->console().mediaSource().update(); + m_settings->step(m_osystem->console().system()); + } + } else { + // In joystick mode we only need to set the action events once + m_state.setActionJoysticksContinuous( + event, player_a_r, player_a_theta, player_a_fire, + player_b_r, player_b_theta, player_b_fire, + m_continuous_action_threshold); + + for (size_t t = 0; t < num_steps; t++) { + m_osystem->console().mediaSource().update(); + m_settings->step(m_osystem->console().system()); + } + } + + // Parse screen and RAM into their respective data structures + processScreen(); + processRAM(); +} + /** Accessor methods for the environment state. 
*/ void StellaEnvironment::setState(const ALEState& state) { m_state = state; } diff --git a/src/environment/stella_environment.hpp b/src/environment/stella_environment.hpp index 803f1e02b..405fd2ca1 100644 --- a/src/environment/stella_environment.hpp +++ b/src/environment/stella_environment.hpp @@ -60,6 +60,17 @@ class StellaEnvironment { */ reward_t act(Action player_a_action, Action player_b_action); + /** Applies the given continuous actions (e.g. updating paddle positions when + * the paddle is used) and performs one simulation step in Stella. Returns the + * resultant reward. When frame skip is set to > 1, up the corresponding + * number of simulation steps are performed. Note that the post-act() frame + * number might not correspond to the pre-act() frame number plus the frame + * skip. + */ + reward_t actContinuous( + float player_a_r, float player_a_theta, float player_a_fire, + float player_b_r, float player_b_theta, float player_b_fire); + /** This functions emulates a push on the reset button of the console */ void softReset(); @@ -121,9 +132,21 @@ class StellaEnvironment { /** This applies an action exactly one time step. Helper function to act(). */ reward_t oneStepAct(Action player_a_action, Action player_b_action); + /** This applies a continuous action exactly one time step. + * Helper function to actContinuous(). + */ + reward_t oneStepActContinuous( + float player_a_r, float player_a_theta, float player_a_fire, + float player_b_r, float player_b_theta, float player_b_fire); + + /** Actually emulates the emulator for a given number of steps. */ void emulate(Action player_a_action, Action player_b_action, size_t num_steps = 1); + void emulateContinuous( + float player_a_r, float player_a_theta, float player_a_fire, + float player_b_r, float player_b_theta, float player_b_fire, + size_t num_steps = 1); /** Drops illegal actions, such as the fire button in skiing. Note that this is different * from the minimal set of actions. 
*/ @@ -153,6 +176,7 @@ class StellaEnvironment { int m_max_num_frames_per_episode; // Maxmimum number of frames per episode size_t m_frame_skip; // How many frames to emulate per act() float m_repeat_action_probability; // Stochasticity of the environment + float m_continuous_action_threshold; // Continuous action threshold std::unique_ptr m_screen_exporter; // Automatic screen recorder int m_max_lives; // Maximum number of lives at the start of an episode. bool m_truncate_on_loss_of_life; // Whether to truncate episodes on loss of life. @@ -161,6 +185,9 @@ class StellaEnvironment { // The last actions taken by our players Action m_player_a_action, m_player_b_action; + float m_player_a_r, m_player_b_r; + float m_player_a_theta, m_player_b_theta; + float m_player_a_fire, m_player_b_fire; }; } // namespace ale diff --git a/src/python/__init__.pyi b/src/python/__init__.pyi index d1c619727..cbc3f7881 100644 --- a/src/python/__init__.pyi +++ b/src/python/__init__.pyi @@ -103,6 +103,7 @@ class ALEInterface: def __init__(self) -> None: ... @overload def act(self, action: Action) -> int: ... + def actContinuous(self, r: float, theta: float, fire: float) -> int: ... @overload def act(self, action: int) -> int: ... def cloneState(self, *, include_rng: bool = False) -> ALEState: ... 
diff --git a/src/python/ale_python_interface.hpp b/src/python/ale_python_interface.hpp index 7a6fb057e..6377f8db7 100644 --- a/src/python/ale_python_interface.hpp +++ b/src/python/ale_python_interface.hpp @@ -146,6 +146,7 @@ PYBIND11_MODULE(_ale_py, m) { ale::ALEPythonInterface::act) .def("act", (ale::reward_t(ale::ALEInterface::*)(ale::Action)) & ale::ALEInterface::act) + .def("actContinuous", &ale::ALEPythonInterface::actContinuous) .def("game_over", &ale::ALEPythonInterface::game_over, py::kw_only(), py::arg("with_truncation") = py::bool_(true)) .def("game_truncated", &ale::ALEPythonInterface::game_truncated) .def("reset_game", &ale::ALEPythonInterface::reset_game) diff --git a/src/python/env.py b/src/python/env.py index abb918f23..129052fc8 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -41,6 +41,8 @@ def __init__( frameskip: tuple[int, int] | int = 4, repeat_action_probability: float = 0.25, full_action_space: bool = False, + continuous: bool = False, + continuous_action_threshold: float = 0.5, max_num_frames_per_episode: int | None = None, render_mode: Literal["human", "rgb_array"] | None = None, ): @@ -58,6 +60,8 @@ def __init__( repeat_action_probability: int => Probability to repeat actions, see Machado et al., 2018 full_action_space: bool => Use full action space? + continuous: bool => Use continuous actions? + continuous_action_threshold: float => threshold used for continuous actions. max_num_frames_per_episode: int => Max number of frame per epsiode. Once `max_num_frames_per_episode` is reached the episode is truncated. @@ -117,6 +121,8 @@ def __init__( frameskip=frameskip, repeat_action_probability=repeat_action_probability, full_action_space=full_action_space, + continuous=continuous, + continuous_action_threshold=continuous_action_threshold, max_num_frames_per_episode=max_num_frames_per_episode, render_mode=render_mode, ) @@ -136,6 +142,8 @@ def __init__( self.ale.setLoggerMode(ale_py.LoggerMode.Error) # Config sticky action prob. 
self.ale.setFloat("repeat_action_probability", repeat_action_probability) + # Config continuous action threshold (if using continuous actions). + self.ale.setFloat("continuous_action_threshold", continuous_action_threshold) if max_num_frames_per_episode is not None: self.ale.setInt("max_num_frames_per_episode", max_num_frames_per_episode) @@ -149,13 +157,24 @@ def __init__( self.seed_game() self.load_game() - # initialize action space - self._action_set = ( - self.ale.getLegalActionSet() - if full_action_space - else self.ale.getMinimalActionSet() - ) - self.action_space = spaces.Discrete(len(self._action_set)) + self.continuous = continuous + self.continuous_action_threshold = continuous_action_threshold + if continuous: + # We don't need action_set for continuous actions. + self._action_set = None + # Actions are radius, theta, and fire, where first two are the + # parameters of polar coordinates. + self.action_space = spaces.Box( + np.array([0, -np.pi, 0]).astype(np.float32), + np.array([+1, +np.pi, +1]).astype(np.float32), + ) # radius, theta, fire. First two are polar coordinates. + else: + self._action_set = ( + self.ale.getLegalActionSet() + if full_action_space + else self.ale.getMinimalActionSet() + ) + self.action_space = spaces.Discrete(len(self._action_set)) # initialize observation space if self._obs_type == "ram": @@ -220,13 +239,14 @@ def reset( # pyright: ignore[reportIncompatibleMethodOverride] def step( # pyright: ignore[reportIncompatibleMethodOverride] self, - action: int, + action: int | np.ndarray, ) -> tuple[np.ndarray, float, bool, bool, AtariEnvStepMetadata]: """ Perform one agent step, i.e., repeats `action` frameskip # of steps. Args: - action_ind: int => Action index to execute + action_ind: int | np.ndarray => Action index to execute, or numpy + array of floats if continuous. 
Returns: tuple[np.ndarray, float, bool, bool, Dict[str, Any]] => @@ -247,7 +267,16 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] # Frameskip reward = 0.0 for _ in range(frameskip): - reward += self.ale.act(self._action_set[action]) + if self.continuous: + if len(action) != 3: + raise error.Error("Actions must have 3-dimensions.") + + if isinstance(action, np.ndarray): + action = action.tolist() + r, theta, fire = action + reward += self.ale.actContinuous(r, theta, fire) + else: + reward += self.ale.act(self._action_set[action]) is_terminal = self.ale.game_over(with_truncation=False) is_truncated = self.ale.game_truncated() diff --git a/tests/python/test_atari_env.py b/tests/python/test_atari_env.py index 1521cc01b..54e2d113e 100644 --- a/tests/python/test_atari_env.py +++ b/tests/python/test_atari_env.py @@ -329,6 +329,41 @@ def test_gym_action_space(tetris_env): assert tetris_env.action_space.n == 18 +@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) +def test_continuous_action_space(tetris_env): + assert isinstance(tetris_env.action_space, gymnasium.spaces.Box) + assert len(tetris_env.action_space.shape) == 1 + assert tetris_env.action_space.shape[0] == 3 + np.testing.assert_array_almost_equal( + tetris_env.action_space.low, np.array([0.0, -np.pi, 0.0]) + ) + np.testing.assert_array_almost_equal( + tetris_env.action_space.high, np.array([1.0, np.pi, 1.0]) + ) + + +@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) +def test_continuous_action_sample(tetris_env): + tetris_env.reset(seed=0) + for _ in range(100): + tetris_env.step(tetris_env.action_space.sample()) + + +@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) +def test_continuous_step_with_correct_dimensions(tetris_env): + tetris_env.reset(seed=0) + # Test with both regular list and numpy array. 
+ tetris_env.step([0.0, -0.5, 0.5]) + tetris_env.step(np.array([0.0, -0.5, 0.5])) + + +@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) +def test_continuous_step_fails_with_wrong_dimensions(tetris_env): + tetris_env.reset(seed=0) + with pytest.raises(gymnasium.error.Error): + tetris_env.step([0.0, 1.0]) + + def test_gym_reset_with_infos(tetris_env): pack = tetris_env.reset(seed=0) From 11bbfdbffa99e5ea3fb6423541c631f3c4643f9e Mon Sep 17 00:00:00 2001 From: Jet <38184875+jjshoots@users.noreply.github.com> Date: Fri, 9 Aug 2024 01:20:40 +0900 Subject: [PATCH 2/3] Pure Continuous version of ALE on CPP (#550) Co-authored-by: Jet --- pyproject.toml | 2 +- src/ale_interface.cpp | 13 +- src/ale_interface.hpp | 8 +- src/emucore/Settings.cxx | 1 - src/environment/ale_state.cpp | 156 ++---------------- src/environment/ale_state.hpp | 34 ++-- src/environment/stella_environment.cpp | 149 ++++------------- src/environment/stella_environment.hpp | 24 +-- .../stella_environment_wrapper.cpp | 8 +- .../stella_environment_wrapper.hpp | 3 +- src/python/__init__.pyi | 5 +- src/python/ale_python_interface.hpp | 5 +- src/python/env.py | 138 ++++++++++++---- tests/python/test_atari_env.py | 54 +++--- 14 files changed, 209 insertions(+), 391 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bd8dd0d0a..250e26901 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ authors = [ {name = "Michael Bowling"}, ] maintainers = [ - { name = "Farama Foundation", email = "contact@farama.org" }, + {name = "Farama Foundation", email = "contact@farama.org"}, {name = "Jesse Farebrother", email = "jfarebro@cs.mcgill.ca"}, ] classifiers = [ diff --git a/src/ale_interface.cpp b/src/ale_interface.cpp index eae846fdd..9799bd67f 100644 --- a/src/ale_interface.cpp +++ b/src/ale_interface.cpp @@ -259,16 +259,9 @@ int ALEInterface::lives() { // user's responsibility to check if the game has ended and reset // when necessary - this method will keep pressing 
buttons on the // game over screen. -reward_t ALEInterface::act(Action action) { - return environment->act(action, PLAYER_B_NOOP); -} - -// Applies a continuous action to the game and returns the reward. It is the -// user's responsibility to check if the game has ended and reset -// when necessary - this method will keep pressing buttons on the -// game over screen. -reward_t ALEInterface::actContinuous(float r, float theta, float fire) { - return environment->actContinuous(r, theta, fire, 0.0, 0.0, 0.0); +// Intentionally set player B actions to 0 since we are in single player mode +reward_t ALEInterface::act(Action action, float paddle_strength) { + return environment->act(action, PLAYER_B_NOOP, paddle_strength, 0.0); } // Returns the vector of modes available for the current game. diff --git a/src/ale_interface.hpp b/src/ale_interface.hpp index 832adf592..f7b4d3f1c 100644 --- a/src/ale_interface.hpp +++ b/src/ale_interface.hpp @@ -86,13 +86,7 @@ class ALEInterface { // user's responsibility to check if the game has ended and reset // when necessary - this method will keep pressing buttons on the // game over screen. - reward_t act(Action action); - - // Applies a continuous action to the game and returns the reward. It is the - // user's responsibility to check if the game has ended and reset - // when necessary - this method will keep pressing buttons on the - // game over screen. - reward_t actContinuous(float r, float theta, float fire); + reward_t act(Action action, float paddle_strength = 1.0); // Indicates if the game has ended. 
bool game_over(bool with_truncation = true) const; diff --git a/src/emucore/Settings.cxx b/src/emucore/Settings.cxx index bb11a8ca0..8e9193aae 100644 --- a/src/emucore/Settings.cxx +++ b/src/emucore/Settings.cxx @@ -415,7 +415,6 @@ void Settings::setDefaultSettings() { boolSettings.insert(std::pair("send_rgb", false)); intSettings.insert(std::pair("frame_skip", 1)); floatSettings.insert(std::pair("repeat_action_probability", 0.25)); - floatSettings.insert(std::pair("continuous_action_threshold", 0.5)); stringSettings.insert(std::pair("rom_file", "")); // Whether to truncate an episode on loss of life. boolSettings.insert(std::pair("truncate_on_loss_of_life", false)); diff --git a/src/environment/ale_state.cpp b/src/environment/ale_state.cpp index 9c7ae8393..4b19c4622 100644 --- a/src/environment/ale_state.cpp +++ b/src/environment/ale_state.cpp @@ -189,16 +189,14 @@ void ALEState::updatePaddlePositions(Event* event, int delta_left, setPaddles(event, m_left_paddle, m_right_paddle); } -void ALEState::applyActionPaddles(Event* event, int player_a_action, - int player_b_action) { +void ALEState::applyActionPaddles(Event* event, + int player_a_action, float paddle_a_strength, + int player_b_action, float paddle_b_strength) { // Reset keys resetKeys(event); - // First compute whether we should increase or decrease the paddle position - // (for both left and right players) - int delta_left; - int delta_right; - + int delta_a = 0; + int delta_b = 0; switch (player_a_action) { case PLAYER_A_RIGHT: case PLAYER_A_RIGHTFIRE: @@ -206,7 +204,7 @@ void ALEState::applyActionPaddles(Event* event, int player_a_action, case PLAYER_A_DOWNRIGHT: case PLAYER_A_UPRIGHTFIRE: case PLAYER_A_DOWNRIGHTFIRE: - delta_left = -PADDLE_DELTA; + delta_a = static_cast(-PADDLE_DELTA * fabs(paddle_a_strength)); break; case PLAYER_A_LEFT: @@ -215,10 +213,10 @@ void ALEState::applyActionPaddles(Event* event, int player_a_action, case PLAYER_A_DOWNLEFT: case PLAYER_A_UPLEFTFIRE: case 
PLAYER_A_DOWNLEFTFIRE: - delta_left = PADDLE_DELTA; + delta_a = static_cast(PADDLE_DELTA * fabs(paddle_a_strength)); break; + default: - delta_left = 0; break; } @@ -229,7 +227,7 @@ void ALEState::applyActionPaddles(Event* event, int player_a_action, case PLAYER_B_DOWNRIGHT: case PLAYER_B_UPRIGHTFIRE: case PLAYER_B_DOWNRIGHTFIRE: - delta_right = -PADDLE_DELTA; + delta_b = static_cast(-PADDLE_DELTA * fabs(paddle_b_strength)); break; case PLAYER_B_LEFT: @@ -238,15 +236,15 @@ void ALEState::applyActionPaddles(Event* event, int player_a_action, case PLAYER_B_DOWNLEFT: case PLAYER_B_UPLEFTFIRE: case PLAYER_B_DOWNLEFTFIRE: - delta_right = PADDLE_DELTA; + delta_b = static_cast(PADDLE_DELTA * fabs(paddle_b_strength)); break; + default: - delta_right = 0; break; } // Now update the paddle positions - updatePaddlePositions(event, delta_left, delta_right); + updatePaddlePositions(event, delta_a, delta_b); // Handle reset if (player_a_action == RESET || player_b_action == RESET) @@ -288,47 +286,6 @@ void ALEState::applyActionPaddles(Event* event, int player_a_action, } } -void ALEState::applyActionPaddlesContinuous( - Event* event, - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, - float continuous_action_threshold) { - // Reset keys - resetKeys(event); - - // Convert polar coordinates to x/y position. - float a_x = player_a_r * cos(player_a_theta); - float a_y = player_a_r * sin(player_a_theta); - float b_x = player_b_r * cos(player_b_theta); - float b_y = player_b_r * sin(player_b_theta); - - // First compute whether we should increase or decrease the paddle position - // (for both left and right players) - int delta_a = 0; - if (a_x > continuous_action_threshold) { // Right action. - delta_a = -PADDLE_DELTA; - } else if (a_x < -continuous_action_threshold) { // Left action. - delta_a = PADDLE_DELTA; - } - int delta_b = 0; - if (b_x > continuous_action_threshold) { // Right action. 
- delta_b = -PADDLE_DELTA; - } else if (b_x < -continuous_action_threshold) { // Left action. - delta_b = PADDLE_DELTA; - } - - // Now update the paddle positions - updatePaddlePositions(event, delta_a, delta_b); - - // Now add the fire event - if (player_a_fire > continuous_action_threshold) { - event->set(Event::PaddleZeroFire, 1); - } - if (player_b_fire > continuous_action_threshold) { - event->set(Event::PaddleOneFire, 1); - } -} - void ALEState::pressSelect(Event* event) { resetKeys(event); event->set(Event::ConsoleSelect, 1); @@ -343,93 +300,75 @@ void ALEState::setDifficultySwitches(Event* event, unsigned int value) { event->set(Event::ConsoleRightDifficultyB, !((value & 2) >> 1)); } -void ALEState::setActionJoysticks(Event* event, int player_a_action, - int player_b_action) { +void ALEState::applyActionJoysticks(Event* event, + int player_a_action, int player_b_action) { // Reset keys resetKeys(event); - switch (player_a_action) { case PLAYER_A_NOOP: break; - case PLAYER_A_FIRE: event->set(Event::JoystickZeroFire, 1); break; - case PLAYER_A_UP: event->set(Event::JoystickZeroUp, 1); break; - case PLAYER_A_RIGHT: event->set(Event::JoystickZeroRight, 1); break; - case PLAYER_A_LEFT: event->set(Event::JoystickZeroLeft, 1); break; - case PLAYER_A_DOWN: event->set(Event::JoystickZeroDown, 1); break; - case PLAYER_A_UPRIGHT: event->set(Event::JoystickZeroUp, 1); event->set(Event::JoystickZeroRight, 1); break; - case PLAYER_A_UPLEFT: event->set(Event::JoystickZeroUp, 1); event->set(Event::JoystickZeroLeft, 1); break; - case PLAYER_A_DOWNRIGHT: event->set(Event::JoystickZeroDown, 1); event->set(Event::JoystickZeroRight, 1); break; - case PLAYER_A_DOWNLEFT: event->set(Event::JoystickZeroDown, 1); event->set(Event::JoystickZeroLeft, 1); break; - case PLAYER_A_UPFIRE: event->set(Event::JoystickZeroUp, 1); event->set(Event::JoystickZeroFire, 1); break; - case PLAYER_A_RIGHTFIRE: event->set(Event::JoystickZeroRight, 1); event->set(Event::JoystickZeroFire, 1); break; - 
case PLAYER_A_LEFTFIRE: event->set(Event::JoystickZeroLeft, 1); event->set(Event::JoystickZeroFire, 1); break; - case PLAYER_A_DOWNFIRE: event->set(Event::JoystickZeroDown, 1); event->set(Event::JoystickZeroFire, 1); break; - case PLAYER_A_UPRIGHTFIRE: event->set(Event::JoystickZeroUp, 1); event->set(Event::JoystickZeroRight, 1); event->set(Event::JoystickZeroFire, 1); break; - case PLAYER_A_UPLEFTFIRE: event->set(Event::JoystickZeroUp, 1); event->set(Event::JoystickZeroLeft, 1); event->set(Event::JoystickZeroFire, 1); break; - case PLAYER_A_DOWNRIGHTFIRE: event->set(Event::JoystickZeroDown, 1); event->set(Event::JoystickZeroRight, 1); event->set(Event::JoystickZeroFire, 1); break; - case PLAYER_A_DOWNLEFTFIRE: event->set(Event::JoystickZeroDown, 1); event->set(Event::JoystickZeroLeft, 1); @@ -437,94 +376,77 @@ void ALEState::setActionJoysticks(Event* event, int player_a_action, break; case RESET: event->set(Event::ConsoleReset, 1); + Logger::Info << "Sending Reset...\n"; break; default: Logger::Error << "Invalid Player A Action: " << player_a_action << "\n"; std::exit(-1); } - switch (player_b_action) { case PLAYER_B_NOOP: break; - case PLAYER_B_FIRE: event->set(Event::JoystickOneFire, 1); break; - case PLAYER_B_UP: event->set(Event::JoystickOneUp, 1); break; - case PLAYER_B_RIGHT: event->set(Event::JoystickOneRight, 1); break; - case PLAYER_B_LEFT: event->set(Event::JoystickOneLeft, 1); break; - case PLAYER_B_DOWN: event->set(Event::JoystickOneDown, 1); break; - case PLAYER_B_UPRIGHT: event->set(Event::JoystickOneUp, 1); event->set(Event::JoystickOneRight, 1); break; - case PLAYER_B_UPLEFT: event->set(Event::JoystickOneUp, 1); event->set(Event::JoystickOneLeft, 1); break; - case PLAYER_B_DOWNRIGHT: event->set(Event::JoystickOneDown, 1); event->set(Event::JoystickOneRight, 1); break; - case PLAYER_B_DOWNLEFT: event->set(Event::JoystickOneDown, 1); event->set(Event::JoystickOneLeft, 1); break; - case PLAYER_B_UPFIRE: event->set(Event::JoystickOneUp, 1); 
event->set(Event::JoystickOneFire, 1); break; - case PLAYER_B_RIGHTFIRE: event->set(Event::JoystickOneRight, 1); event->set(Event::JoystickOneFire, 1); break; - case PLAYER_B_LEFTFIRE: event->set(Event::JoystickOneLeft, 1); event->set(Event::JoystickOneFire, 1); break; - case PLAYER_B_DOWNFIRE: event->set(Event::JoystickOneDown, 1); event->set(Event::JoystickOneFire, 1); break; - case PLAYER_B_UPRIGHTFIRE: event->set(Event::JoystickOneUp, 1); event->set(Event::JoystickOneRight, 1); event->set(Event::JoystickOneFire, 1); break; - case PLAYER_B_UPLEFTFIRE: event->set(Event::JoystickOneUp, 1); event->set(Event::JoystickOneLeft, 1); event->set(Event::JoystickOneFire, 1); break; - case PLAYER_B_DOWNRIGHTFIRE: event->set(Event::JoystickOneDown, 1); event->set(Event::JoystickOneRight, 1); event->set(Event::JoystickOneFire, 1); break; - case PLAYER_B_DOWNLEFTFIRE: event->set(Event::JoystickOneDown, 1); event->set(Event::JoystickOneLeft, 1); @@ -540,54 +462,6 @@ void ALEState::setActionJoysticks(Event* event, int player_a_action, } } - -void ALEState::setActionJoysticksContinuous( - Event* event, - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, - float continuous_action_threshold) { - // Reset keys - resetKeys(event); - - // Convert polar coordinates to x/y position. - float a_x = player_a_r * cos(player_a_theta); - float a_y = player_a_r * sin(player_a_theta); - float b_x = player_b_r * cos(player_b_theta); - float b_y = player_b_r * sin(player_b_theta); - - // Go through all possible events and add them if joystick position is there. 
- if (a_x < -continuous_action_threshold) { - event->set(Event::JoystickZeroLeft, 1); - } - if (a_x > continuous_action_threshold) { - event->set(Event::JoystickZeroRight, 1); - } - if (a_y < -continuous_action_threshold) { - event->set(Event::JoystickZeroDown, 1); - } - if (a_y > continuous_action_threshold) { - event->set(Event::JoystickZeroUp, 1); - } - if (player_a_fire > continuous_action_threshold) { - event->set(Event::JoystickZeroFire, 1); - } - if (b_x < -continuous_action_threshold) { - event->set(Event::JoystickOneLeft, 1); - } - if (b_x > continuous_action_threshold) { - event->set(Event::JoystickOneRight, 1); - } - if (b_y < -continuous_action_threshold) { - event->set(Event::JoystickOneDown, 1); - } - if (b_y > continuous_action_threshold) { - event->set(Event::JoystickOneUp, 1); - } - if (player_b_fire > continuous_action_threshold) { - event->set(Event::JoystickOneFire, 1); - } -} - /* *************************************************************************** Function resetKeys Unpresses all control-relevant keys diff --git a/src/environment/ale_state.hpp b/src/environment/ale_state.hpp index 6aad84683..3c0d2a529 100644 --- a/src/environment/ale_state.hpp +++ b/src/environment/ale_state.hpp @@ -58,29 +58,17 @@ class ALEState { void resetPaddles(stella::Event*); //Apply the special select action - void pressSelect(stella::Event* event_obj); + void pressSelect(stella::Event* event); - /** Applies paddle actions. This actually modifies the game state by updating the paddle - * resistances. */ - void applyActionPaddles(stella::Event* event_obj, int player_a_action, - int player_b_action); /** Applies paddle continuous actions. This actually modifies the game state * by updating the paddle resistances. 
*/ - void applyActionPaddlesContinuous( - stella::Event* event_obj, - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, - float continuous_action_treshold = 0.5); + void applyActionPaddles(stella::Event* event, + int player_a_action, float paddle_a_strength, + int player_b_action, float paddle_b_strength); + /** Sets the joystick events. No effect until the emulator is run forward. */ - void setActionJoysticks(stella::Event* event_obj, int player_a_action, - int player_b_action); - /** Sets the joystick events for continuous actions. No effect until the - * emulator is run forward. */ - void setActionJoysticksContinuous( - stella::Event* event_obj, - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, - float continuous_action_threshold = 0.5); + void applyActionJoysticks(stella::Event* event, + int player_a_action, int player_b_action); void incrementFrame(int steps = 1); @@ -123,22 +111,22 @@ class ALEState { ALEState save(stella::OSystem* osystem, RomSettings* settings, std::optional rng, std::string md5); /** Reset key presses */ - void resetKeys(stella::Event* event_obj); + void resetKeys(stella::Event* event); /** Sets the paddle to a given position */ - void setPaddles(stella::Event* event_obj, int left, int right); + void setPaddles(stella::Event* event, int left, int right); /** Set the paddle min/max values */ void setPaddleLimits(int paddle_min_val, int paddle_max_val); /** Updates the paddle position by a delta amount. 
*/ - void updatePaddlePositions(stella::Event* event_obj, int delta_x, int delta_y); + void updatePaddlePositions(stella::Event* event, int delta_x, int delta_y); /** Calculates the Paddle resistance, based on the given x val */ int calcPaddleResistance(int x_val); /** Applies the current difficulty setting, which is effectively part of the action */ - void setDifficultySwitches(stella::Event* event_obj, unsigned int value); + void setDifficultySwitches(stella::Event* event, unsigned int value); private: int m_left_paddle; // Current value for the left-paddle diff --git a/src/environment/stella_environment.cpp b/src/environment/stella_environment.cpp index 5e57c8bf9..3c16e9f9e 100644 --- a/src/environment/stella_environment.cpp +++ b/src/environment/stella_environment.cpp @@ -76,8 +76,6 @@ StellaEnvironment::StellaEnvironment(OSystem* osystem, RomSettings* settings) m_repeat_action_probability = m_osystem->settings().getFloat("repeat_action_probability"); - m_continuous_action_threshold = - m_osystem->settings().getFloat("continuous_action_threshold"); m_frame_skip = m_osystem->settings().getInt("frame_skip"); if (m_frame_skip < 1) { @@ -109,7 +107,7 @@ void StellaEnvironment::reset() { int noopSteps; noopSteps = 60; - emulate(PLAYER_A_NOOP, PLAYER_B_NOOP, noopSteps); + emulate(PLAYER_A_NOOP, PLAYER_B_NOOP, 1.0, 1.0, noopSteps); // Reset the emulator softReset(); @@ -124,7 +122,7 @@ void StellaEnvironment::reset() { // Apply necessary actions specified by the rom itself ActionVect startingActions = m_settings->getStartingActions(); for (size_t i = 0; i < startingActions.size(); i++) { - emulate(startingActions[i], PLAYER_B_NOOP); + emulate(startingActions[i], PLAYER_B_NOOP, 1.0, 1.0); } } @@ -154,44 +152,8 @@ void StellaEnvironment::noopIllegalActions(Action& player_a_action, player_b_action = (Action)PLAYER_B_NOOP; } -reward_t StellaEnvironment::act(Action player_a_action, - Action player_b_action) { - // Total reward received as we repeat the action - reward_t 
sum_rewards = 0; - - Random& rng = getEnvironmentRNG(); - - // Apply the same action for a given number of times... note that act() will refuse to emulate - // past the terminal state - for (size_t i = 0; i < m_frame_skip; i++) { - // Stochastically drop actions, according to m_repeat_action_probability - if (rng.nextDouble() >= m_repeat_action_probability) - m_player_a_action = player_a_action; - // @todo Possibly optimize by avoiding call to rand() when player B is "off" ? - if (rng.nextDouble() >= m_repeat_action_probability) - m_player_b_action = player_b_action; - - // If so desired, request one frame's worth of sound (this does nothing if recording - // is not enabled) - m_osystem->sound().recordNextFrame(); - - // Render screen if we're displaying it - m_osystem->screen().render(); - - // Similarly record screen as needed - if (m_screen_exporter.get() != NULL) - m_screen_exporter->saveNext(m_screen); - - // Use the stored actions, which may or may not have changed this frame - sum_rewards += oneStepAct(m_player_a_action, m_player_b_action); - } - - return std::clamp(sum_rewards, m_reward_min, m_reward_max); -} - -reward_t StellaEnvironment::actContinuous( - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire) { +reward_t StellaEnvironment::act(Action player_a_action, Action player_b_action, + float paddle_a_strength, float paddle_b_strength) { // Total reward received as we repeat the action reward_t sum_rewards = 0; @@ -202,15 +164,13 @@ reward_t StellaEnvironment::actContinuous( for (size_t i = 0; i < m_frame_skip; i++) { // Stochastically drop actions, according to m_repeat_action_probability if (rng.nextDouble() >= m_repeat_action_probability) { - m_player_a_r = player_a_r; - m_player_a_theta = player_a_theta; - m_player_a_fire = player_a_fire; + m_player_a_action = player_a_action; + m_paddle_a_strength = paddle_a_strength; } // @todo Possibly optimize by avoiding call to rand() when 
player B is "off" ? if (rng.nextDouble() >= m_repeat_action_probability) { - m_player_b_r = player_b_r; - m_player_b_theta = player_b_theta; - m_player_b_fire = player_b_fire; + m_player_b_action = player_b_action; + m_paddle_b_strength = paddle_b_strength; } // If so desired, request one frame's worth of sound (this does nothing if recording @@ -225,16 +185,16 @@ reward_t StellaEnvironment::actContinuous( m_screen_exporter->saveNext(m_screen); // Use the stored actions, which may or may not have changed this frame - sum_rewards += oneStepActContinuous(m_player_a_r, m_player_a_theta, m_player_a_fire, - m_player_b_r, m_player_b_theta, m_player_b_fire); + sum_rewards += oneStepAct(m_player_a_action, m_player_b_action, + m_paddle_a_strength, m_paddle_b_strength); } - return sum_rewards; + return std::clamp(sum_rewards, m_reward_min, m_reward_max); } /** This functions emulates a push on the reset button of the console */ void StellaEnvironment::softReset() { - emulate(RESET, PLAYER_B_NOOP, m_num_reset_steps); + emulate(RESET, PLAYER_B_NOOP, 1.0, 1.0, m_num_reset_steps); // Reset previous actions to NOOP for correct action repeating m_player_a_action = PLAYER_A_NOOP; @@ -243,8 +203,8 @@ void StellaEnvironment::softReset() { /** Applies the given actions (e.g. updating paddle positions when the paddle is used) * and performs one simulation step in Stella. */ -reward_t StellaEnvironment::oneStepAct(Action player_a_action, - Action player_b_action) { +reward_t StellaEnvironment::oneStepAct(Action player_a_action, Action player_b_action, + float paddle_a_strength, float paddle_b_strength) { // Once in a terminal state, refuse to go any further (special actions must be handled // outside of this environment; in particular reset() should be called rather than passing // RESET or SYSTEM_RESET. 
@@ -255,30 +215,8 @@ reward_t StellaEnvironment::oneStepAct(Action player_a_action, noopIllegalActions(player_a_action, player_b_action); // Emulate in the emulator - emulate(player_a_action, player_b_action); - // Increment the number of frames seen so far - m_state.incrementFrame(); - - return m_settings->getReward(); -} - -/** Applies the given continuous actions (e.g. updating paddle positions when - * the paddle is used) and performs one simulation step in Stella. */ -reward_t StellaEnvironment::oneStepActContinuous( - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire) { - // Once in a terminal state, refuse to go any further (special actions must be handled - // outside of this environment; in particular reset() should be called rather than passing - // RESET or SYSTEM_RESET. - if (isTerminal()) - return 0; - - // Convert illegal actions into NOOPs; actions such as reset are always legal - //noopIllegalActions(player_a_action, player_b_action); - - // Emulate in the emulator - emulateContinuous(player_a_r, player_a_theta, player_a_fire, - player_b_r, player_b_theta, player_b_fire); + emulate(player_a_action, player_b_action, + paddle_a_strength, paddle_b_strength); // Increment the number of frames seen so far m_state.incrementFrame(); @@ -314,7 +252,7 @@ void StellaEnvironment::pressSelect(size_t num_steps) { } processScreen(); processRAM(); - emulate(PLAYER_A_NOOP, PLAYER_B_NOOP); + emulate(PLAYER_A_NOOP, PLAYER_B_NOOP, 1.0, 1.0); m_state.incrementFrame(); } @@ -326,39 +264,11 @@ void StellaEnvironment::setMode(game_mode_t value) { m_state.setCurrentMode(value); } -void StellaEnvironment::emulate(Action player_a_action, Action player_b_action, - size_t num_steps) { - Event* event = m_osystem->event(); - - // Handle paddles separately: we have to manually update the paddle positions at each step - if (m_use_paddles) { - // Run emulator forward for 'num_steps' - for (size_t t = 0; t < 
num_steps; t++) { - // Update paddle position at every step - m_state.applyActionPaddles(event, player_a_action, player_b_action); - - m_osystem->console().mediaSource().update(); - m_settings->step(m_osystem->console().system()); - } - } else { - // In joystick mode we only need to set the action events once - m_state.setActionJoysticks(event, player_a_action, player_b_action); - - for (size_t t = 0; t < num_steps; t++) { - m_osystem->console().mediaSource().update(); - m_settings->step(m_osystem->console().system()); - } - } - - // Parse screen and RAM into their respective data structures - processScreen(); - processRAM(); -} - -void StellaEnvironment::emulateContinuous( - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, - size_t num_steps) { +void StellaEnvironment::emulate( + Action player_a_action, Action player_b_action, + float paddle_a_strength, float paddle_b_strength, + size_t num_steps +) { Event* event = m_osystem->event(); // Handle paddles separately: we have to manually update the paddle positions at each step @@ -366,21 +276,18 @@ void StellaEnvironment::emulateContinuous( // Run emulator forward for 'num_steps' for (size_t t = 0; t < num_steps; t++) { // Update paddle position at every step - m_state.applyActionPaddlesContinuous( - event, - player_a_r, player_a_theta, player_a_fire, - player_b_r, player_b_theta, player_b_fire, - m_continuous_action_threshold); + m_state.applyActionPaddles( + event, + player_a_action, paddle_a_strength, + player_b_action, paddle_b_strength + ); m_osystem->console().mediaSource().update(); m_settings->step(m_osystem->console().system()); } } else { // In joystick mode we only need to set the action events once - m_state.setActionJoysticksContinuous( - event, player_a_r, player_a_theta, player_a_fire, - player_b_r, player_b_theta, player_b_fire, - m_continuous_action_threshold); + m_state.applyActionJoysticks(event, player_a_action, 
player_b_action); for (size_t t = 0; t < num_steps; t++) { m_osystem->console().mediaSource().update(); diff --git a/src/environment/stella_environment.hpp b/src/environment/stella_environment.hpp index 405fd2ca1..f2a59a0b2 100644 --- a/src/environment/stella_environment.hpp +++ b/src/environment/stella_environment.hpp @@ -58,7 +58,8 @@ class StellaEnvironment { * Note that the post-act() frame number might not correspond to the pre-act() frame * number plus the frame skip. */ - reward_t act(Action player_a_action, Action player_b_action); + reward_t act(Action player_a_action, Action player_b_action, + float paddle_a_strength = 1.0, float paddle_b_strength = 1.0); /** Applies the given continuous actions (e.g. updating paddle positions when * the paddle is used) and performs one simulation step in Stella. Returns the @@ -67,9 +68,6 @@ class StellaEnvironment { * number might not correspond to the pre-act() frame number plus the frame * skip. */ - reward_t actContinuous( - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire); /** This functions emulates a push on the reset button of the console */ void softReset(); @@ -130,23 +128,13 @@ class StellaEnvironment { private: /** This applies an action exactly one time step. Helper function to act(). */ - reward_t oneStepAct(Action player_a_action, Action player_b_action); - - /** This applies a continuous action exactly one time step. - * Helper function to actContinuous(). - */ - reward_t oneStepActContinuous( - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire); - + reward_t oneStepAct(Action player_a_action, Action player_b_action, + float paddle_a_strength, float paddle_b_strength); /** Actually emulates the emulator for a given number of steps. 
*/ void emulate(Action player_a_action, Action player_b_action, + float paddle_a_strength, float paddle_b_strength, size_t num_steps = 1); - void emulateContinuous( - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, - size_t num_steps = 1); /** Drops illegal actions, such as the fire button in skiing. Note that this is different * from the minimal set of actions. */ @@ -176,7 +164,6 @@ class StellaEnvironment { int m_max_num_frames_per_episode; // Maxmimum number of frames per episode size_t m_frame_skip; // How many frames to emulate per act() float m_repeat_action_probability; // Stochasticity of the environment - float m_continuous_action_threshold; // Continuous action threshold std::unique_ptr m_screen_exporter; // Automatic screen recorder int m_max_lives; // Maximum number of lives at the start of an episode. bool m_truncate_on_loss_of_life; // Whether to truncate episodes on loss of life. @@ -185,6 +172,7 @@ class StellaEnvironment { // The last actions taken by our players Action m_player_a_action, m_player_b_action; + float m_paddle_a_strength, m_paddle_b_strength; float m_player_a_r, m_player_b_r; float m_player_a_theta, m_player_b_theta; float m_player_a_fire, m_player_b_fire; diff --git a/src/environment/stella_environment_wrapper.cpp b/src/environment/stella_environment_wrapper.cpp index f518943e6..34fe8da7c 100644 --- a/src/environment/stella_environment_wrapper.cpp +++ b/src/environment/stella_environment_wrapper.cpp @@ -9,9 +9,11 @@ StellaEnvironmentWrapper::StellaEnvironmentWrapper( StellaEnvironment& environment) : m_environment(environment) {} -reward_t StellaEnvironmentWrapper::act(Action player_a_action, - Action player_b_action) { - return m_environment.act(player_a_action, player_b_action); +reward_t StellaEnvironmentWrapper::act(Action player_a_action, Action player_b_action, + float paddle_a_strength, float paddle_b_strength) { + return 
m_environment.act(player_a_action, player_b_action, + paddle_a_strength, paddle_b_strength + ); } void StellaEnvironmentWrapper::softReset() { m_environment.softReset(); } diff --git a/src/environment/stella_environment_wrapper.hpp b/src/environment/stella_environment_wrapper.hpp index a6f4fbe65..9e97ee031 100644 --- a/src/environment/stella_environment_wrapper.hpp +++ b/src/environment/stella_environment_wrapper.hpp @@ -30,7 +30,8 @@ class StellaEnvironmentWrapper { // stella_environment.hpp. public: StellaEnvironmentWrapper(StellaEnvironment& environment); - reward_t act(Action player_a_action, Action player_b_action); + reward_t act(Action player_a_action, Action player_b_action, + float paddle_a_strength = 1.0, float paddle_b_strength = 1.0); void softReset(); void pressSelect(size_t num_steps = 1); stella::Random& getEnvironmentRNG(); diff --git a/src/python/__init__.pyi b/src/python/__init__.pyi index cbc3f7881..4025420f0 100644 --- a/src/python/__init__.pyi +++ b/src/python/__init__.pyi @@ -102,10 +102,9 @@ class ALEState: class ALEInterface: def __init__(self) -> None: ... @overload - def act(self, action: Action) -> int: ... - def actContinuous(self, r: float, theta: float, fire: float) -> int: ... + def act(self, action: Action, paddle_strength: float = 1.0) -> int: ... @overload - def act(self, action: int) -> int: ... + def act(self, action: int, paddle_strength: float = 1.0) -> int: ... def cloneState(self, *, include_rng: bool = False) -> ALEState: ... def cloneSystemState(self) -> ALEState: ... def game_over(self, *, with_truncation: bool = True) -> bool: ... 
diff --git a/src/python/ale_python_interface.hpp b/src/python/ale_python_interface.hpp index 6377f8db7..d6cb5920a 100644 --- a/src/python/ale_python_interface.hpp +++ b/src/python/ale_python_interface.hpp @@ -144,9 +144,12 @@ PYBIND11_MODULE(_ale_py, m) { .def_static("isSupportedROM", &ale::ALEInterface::isSupportedROM) .def("act", (ale::reward_t(ale::ALEPythonInterface::*)(uint32_t)) & ale::ALEPythonInterface::act) + .def("act", (ale::reward_t(ale::ALEPythonInterface::*)(uint32_t, float)) & + ale::ALEPythonInterface::act) .def("act", (ale::reward_t(ale::ALEInterface::*)(ale::Action)) & ale::ALEInterface::act) - .def("actContinuous", &ale::ALEPythonInterface::actContinuous) + .def("act", (ale::reward_t(ale::ALEInterface::*)(ale::Action, float)) & + ale::ALEInterface::act) .def("game_over", &ale::ALEPythonInterface::game_over, py::kw_only(), py::arg("with_truncation") = py::bool_(true)) .def("game_truncated", &ale::ALEPythonInterface::game_truncated) .def("reset_game", &ale::ALEPythonInterface::reset_game) diff --git a/src/python/env.py b/src/python/env.py index 129052fc8..db9801bc2 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -1,7 +1,9 @@ from __future__ import annotations import sys +from functools import lru_cache from typing import Any, Literal +from warnings import warn import ale_py import gymnasium @@ -142,8 +144,6 @@ def __init__( self.ale.setLoggerMode(ale_py.LoggerMode.Error) # Config sticky action prob. self.ale.setFloat("repeat_action_probability", repeat_action_probability) - # Config continuous action threshold (if using continuous actions). 
- self.ale.setFloat("continuous_action_threshold", continuous_action_threshold) if max_num_frames_per_episode is not None: self.ale.setInt("max_num_frames_per_episode", max_num_frames_per_episode) @@ -157,23 +157,30 @@ def __init__( self.seed_game() self.load_game() + # get the set of legal actions + if continuous and not full_action_space: + warn( + "`continuous` is set to `True`, but `full_action_space` is set to `False`. " + "This will error out when the continuous actions are discretized to illegal action spaces. " + "Therefore, `full_action_space` has been automatically set to `True`." + ) + self._action_set = ( + self.ale.getLegalActionSet() + if (full_action_space or continuous) + else self.ale.getMinimalActionSet() + ) + + # action space self.continuous = continuous self.continuous_action_threshold = continuous_action_threshold if continuous: - # We don't need action_set for continuous actions. - self._action_set = None # Actions are radius, theta, and fire, where first two are the # parameters of polar coordinates. self.action_space = spaces.Box( - np.array([0, -np.pi, 0]).astype(np.float32), - np.array([+1, +np.pi, +1]).astype(np.float32), + np.array([0.0, -np.pi, 0.0]).astype(np.float32), + np.array([1.0, np.pi, 1.0]).astype(np.float32), ) # radius, theta, fire. First two are polar coordinates. else: - self._action_set = ( - self.ale.getLegalActionSet() - if full_action_space - else self.ale.getMinimalActionSet() - ) self.action_space = spaces.Discrete(len(self._action_set)) # initialize observation space @@ -245,8 +252,9 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] Perform one agent step, i.e., repeats `action` frameskip # of steps. Args: - action_ind: int | np.ndarray => Action index to execute, or numpy - array of floats if continuous. 
+ action: int | np.ndarray => + if `continuous=False` -> action index to execute + if `continuous=True` -> numpy array of r, theta, fire Returns: tuple[np.ndarray, float, bool, bool, Dict[str, Any]] => @@ -264,19 +272,33 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] else: raise error.Error(f"Invalid frameskip type: {self._frameskip}") + # action formatting + if self.continuous: + # compute the x, y, fire of the joystick + assert isinstance(action, np.ndarray) + x, y = action[0] * np.cos(action[1]), action[0] * np.sin(action[1]) + action_idx = self.map_action_idx( + left_center_right=( + -int(x < self.continuous_action_threshold) + + int(x > self.continuous_action_threshold) + ), + down_center_up=( + -int(y < self.continuous_action_threshold) + + int(y > self.continuous_action_threshold) + ), + fire=(action[-1] > self.continuous_action_threshold), + ) + + strength = action[0] + else: + action_idx = self._action_set[action] + strength = 1.0 + # Frameskip reward = 0.0 for _ in range(frameskip): - if self.continuous: - if len(action) != 3: - raise error.Error("Actions must have 3-dimensions.") - - if isinstance(action, np.ndarray): - action = action.tolist() - r, theta, fire = action - reward += self.ale.actContinuous(r, theta, fire) - else: - reward += self.ale.act(self._action_set[action]) + reward += self.ale.act(action_idx, strength) + is_terminal = self.ale.game_over(with_truncation=False) is_truncated = self.ale.game_truncated() @@ -323,6 +345,7 @@ def _get_info(self) -> AtariEnvStepMetadata: "frame_number": self.ale.getFrameNumber(), } + @lru_cache(1) def get_keys_to_action(self) -> dict[tuple[int, ...], ale_py.Action]: """ Return keymapping -> actions for human play. @@ -358,12 +381,71 @@ def get_keys_to_action(self) -> dict[tuple[int, ...], ale_py.Action]: # (key, key, ...) 
-> action_idx # where action_idx is the integer value of the action enum # - return dict( - zip( - map(lambda action: tuple(sorted(mapping[action])), self._action_set), - range(len(self._action_set)), + return { + tuple(sorted(mapping[act_idx])): act_idx for act_idx in self._action_set + } + + @lru_cache(18) + def map_action_idx( + self, left_center_right: int, down_center_up: int, fire: bool + ) -> int: + """ + Return an action idx given unit actions for underlying env. + """ + # no op and fire + if left_center_right == 0 and down_center_up == 0 and not fire: + return ale_py.Action.NOOP + elif left_center_right == 0 and down_center_up == 0 and fire: + return ale_py.Action.FIRE + + # cardinal no fire + elif left_center_right == -1 and down_center_up == 0 and not fire: + return ale_py.Action.LEFT + elif left_center_right == 1 and down_center_up == 0 and not fire: + return ale_py.Action.RIGHT + elif left_center_right == 0 and down_center_up == -1 and not fire: + return ale_py.Action.DOWN + elif left_center_right == 0 and down_center_up == 1 and not fire: + return ale_py.Action.UP + + # cardinal fire + if left_center_right == -1 and down_center_up == 0 and fire: + return ale_py.Action.LEFTFIRE + elif left_center_right == 1 and down_center_up == 0 and fire: + return ale_py.Action.RIGHTFIRE + elif left_center_right == 0 and down_center_up == -1 and fire: + return ale_py.Action.DOWNFIRE + elif left_center_right == 0 and down_center_up == 1 and fire: + return ale_py.Action.UPFIRE + + # diagonal no fire + elif left_center_right == -1 and down_center_up == -1 and not fire: + return ale_py.Action.DOWNLEFT + elif left_center_right == 1 and down_center_up == -1 and not fire: + return ale_py.Action.DOWNRIGHT + elif left_center_right == -1 and down_center_up == 1 and not fire: + return ale_py.Action.UPLEFT + elif left_center_right == 1 and down_center_up == 1 and not fire: + return ale_py.Action.UPRIGHT + + # diagonal fire + elif left_center_right == -1 and down_center_up == -1 
and fire: + return ale_py.Action.DOWNLEFTFIRE + elif left_center_right == 1 and down_center_up == -1 and fire: + return ale_py.Action.DOWNRIGHTFIRE + elif left_center_right == -1 and down_center_up == 1 and fire: + return ale_py.Action.UPLEFTFIRE + elif left_center_right == 1 and down_center_up == 1 and fire: + return ale_py.Action.UPRIGHTFIRE + + # just in case + else: + raise LookupError( + "Did not expect to get here, " + "expected `left_center_right` and `down_center_up` to be in {-1, 0, 1} " + "and `fire` to only be `True` or `False`. " + f"Received {left_center_right=}, {down_center_up=} and {fire=}." ) - ) def get_action_meanings(self) -> list[str]: """ diff --git a/tests/python/test_atari_env.py b/tests/python/test_atari_env.py index 54e2d113e..bff58900e 100644 --- a/tests/python/test_atari_env.py +++ b/tests/python/test_atari_env.py @@ -1,3 +1,4 @@ +import itertools import warnings from itertools import product from unittest.mock import patch @@ -10,6 +11,12 @@ from gymnasium.utils.env_checker import check_env from utils import test_rom_path, tetris_env # noqa: F401 +_ACCEPTABLE_WARNING_SNIPPETS = [ + "is out of date. 
You should consider upgrading to version", + "we recommend using a symmetric and normalized space", + "This will error out when the continuous actions are discretized to illegal action spaces", +] + def test_roms_register(): registered_roms = [ @@ -34,14 +41,17 @@ def test_roms_register(): @pytest.mark.parametrize( - "env_id", - [ - env_id - for env_id, spec in gymnasium.registry.items() - if spec.entry_point == "ale_py.env:AtariEnv" - ], + "env_id,continuous", + itertools.product( + [ + env_id + for env_id, spec in gymnasium.registry.items() + if spec.entry_point == "ale_py.env:AtariEnv" + ], + [True, False], + ), ) -def test_check_env(env_id): +def test_check_env(env_id, continuous): if any( unsupported_game in env_id for unsupported_game in ["Warlords", "MazeCraze", "Joust", "Combat"] @@ -49,15 +59,15 @@ def test_check_env(env_id): pytest.skip(env_id) with warnings.catch_warnings(record=True) as caught_warnings: - env = gymnasium.make(env_id).unwrapped + env = gymnasium.make(env_id, continuous=continuous).unwrapped check_env(env, skip_render_check=True) env.close() for warning in caught_warnings: - if ( - "is out of date. You should consider upgrading to version" - not in warning.message.args[0] + if not any( + (snippet in warning.message.args[0]) + for snippet in _ACCEPTABLE_WARNING_SNIPPETS ): raise ValueError(warning.message.args[0]) @@ -342,28 +352,6 @@ def test_continuous_action_space(tetris_env): ) -@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) -def test_continuous_action_sample(tetris_env): - tetris_env.reset(seed=0) - for _ in range(100): - tetris_env.step(tetris_env.action_space.sample()) - - -@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) -def test_continuous_step_with_correct_dimensions(tetris_env): - tetris_env.reset(seed=0) - # Test with both regular list and numpy array. 
- tetris_env.step([0.0, -0.5, 0.5]) - tetris_env.step(np.array([0.0, -0.5, 0.5])) - - -@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) -def test_continuous_step_fails_with_wrong_dimensions(tetris_env): - tetris_env.reset(seed=0) - with pytest.raises(gymnasium.error.Error): - tetris_env.step([0.0, 1.0]) - - def test_gym_reset_with_infos(tetris_env): pack = tetris_env.reset(seed=0) From b4828414cd72c9bf6f9a6ba3e43ed4eb2cc4da27 Mon Sep 17 00:00:00 2001 From: Jet <38184875+jjshoots@users.noreply.github.com> Date: Mon, 12 Aug 2024 18:22:01 +0900 Subject: [PATCH 3/3] Bump to 0.10.0 (#552) Co-authored-by: Jet Co-authored-by: Mark Towers --- .github/workflows/ci.yml | 2 +- CHANGELOG.md | 65 ++++++++++++++++++++++++++++++++++++++++ vcpkg.json | 2 +- version.txt | 2 +- 4 files changed, 68 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7296fde8d..59aef3911 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -289,7 +289,7 @@ jobs: - name: Build # wildcarding doesn't work for some reason, therefore, update the project version here - run: python -m pip install wheels/ale_py-0.9.1-${{ matrix.wheel-name }}.whl + run: python -m pip install wheels/ale_py-0.10.0-${{ matrix.wheel-name }}.whl - name: Install Gymnasium and pytest run: python -m pip install gymnasium>=1.0.0a2 pytest diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c1bd7fae..9447ea4ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,71 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## 0.10.0 - + +Previously in the original ALE interface, the actions are only joystick ActionEnum inputs. 
+Then, for games that use a paddle instead of a joystick, joystick controls are mapped into discrete actions applied to paddles, i.e.: +- All left actions (`LEFTDOWN`, `LEFTUP`, `LEFT...`) -> paddle left max +- All right actions (`RIGHTDOWN`, `RIGHTUP`, `RIGHT...`) -> paddle right max +- Up... etc. +- Down... etc. + +This results in loss of continuous action for paddles. +This change keeps this functionality and interface, but allows for continuous action inputs for games that allow paddle usage. + +To do that, the C++ interface has been modified. + +_Old Discrete ALE interface_ +```cpp +reward_t ALEInterface::act(Action action) +``` + +_New Mixed Discrete-Continuous ALE interface_ +```cpp +reward_t ALEInterface::act(Action action, float paddle_strength = 1.0) +``` + +Games where the paddle is not used simply have the `paddle_strength` parameter ignored. +This mirrors the real-world scenario where you have a paddle connected, but the game doesn't react to it when the paddle is turned. +This maintains backwards compatibility. + +The Python interface has also been updated. + +_Old Discrete ALE Python Interface_ +```py +ale.act(action: int) +``` + +_New Mixed Discrete-Continuous ALE Python Interface_ +```py +ale.act(action: int, strength: float = 1.0) +``` + +More specifically, when a continuous action space is used within an ALE gymnasium environment, discretization happens at the Python level.
+```py +if continuous: + # action is expected to be a [3,] array of floats (radius, theta, fire) + x, y = action[0] * np.cos(action[1]), action[0] * np.sin(action[1]) + action_idx = self.map_action_idx( + left_center_right=( + -int(x < self.continuous_action_threshold) + + int(x > self.continuous_action_threshold) + ), + down_center_up=( + -int(y < self.continuous_action_threshold) + + int(y > self.continuous_action_threshold) + ), + fire=(action[-1] > self.continuous_action_threshold), + ) + ale.act(action_idx, action[1]) +``` + +Here, [`self.map_action_idx`](https://github.com/Farama-Foundation/Arcade-Learning-Environment/pull/550/files#diff-057906329e72d689f1d4d9d9e3f80df11ffe74da581b29b3838a436e90841b5cR388-R447) is an `lru_cache`-ed function that takes the continuous action direction and maps it into an ActionEnum. + +## 0.9.1 - + +Added support for NumPy 2.0. + ## [0.9.0] - 2024-05-10 Previously, ALE implemented only a [Gym](https://github.com/openai/gym) based environment, however, as Gym is no longer maintained (last commit was 18 months ago). We have updated `ale-py` to use [Gymnasium](http://github.com/farama-Foundation/gymnasium) (a maintained fork of Gym) as the sole backend environment implementation. For more information on Gymnasium’s API, see their [introduction page](https://gymnasium.farama.org/main/introduction/basic_usage/). diff --git a/vcpkg.json b/vcpkg.json index 161a99d88..1520fb158 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -1,7 +1,7 @@ { "$schema": "https://raw.githubusercontent.com/microsoft/vcpkg-tool/main/docs/vcpkg.schema.json", "name": "arcade-learning-environment", - "version": "0.9.1", + "version": "0.10.0", "dependencies": [ "zlib" ], diff --git a/version.txt b/version.txt index f374f6662..78bc1abd1 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.9.1 +0.10.0