From 9d3605d55720c8f3112ba18db6022f471b4aa794 Mon Sep 17 00:00:00 2001 From: Jet Date: Thu, 1 Aug 2024 14:43:08 +0900 Subject: [PATCH 01/39] step by step --- src/ale_interface.cpp | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/ale_interface.cpp b/src/ale_interface.cpp index eae846fdd..0e29c18ce 100644 --- a/src/ale_interface.cpp +++ b/src/ale_interface.cpp @@ -259,16 +259,9 @@ int ALEInterface::lives() { // user's responsibility to check if the game has ended and reset // when necessary - this method will keep pressing buttons on the // game over screen. -reward_t ALEInterface::act(Action action) { - return environment->act(action, PLAYER_B_NOOP); -} - -// Applies a continuous action to the game and returns the reward. It is the -// user's responsibility to check if the game has ended and reset -// when necessary - this method will keep pressing buttons on the -// game over screen. -reward_t ALEInterface::actContinuous(float r, float theta, float fire) { - return environment->actContinuous(r, theta, fire, 0.0, 0.0, 0.0); +// Intentionally set player B actions to 0 since we are in single player mode +reward_t ALEInterface::act(float r, float theta, float fire) { + return environment->act(r, theta, fire, 0.0, 0.0, 0.0); } // Returns the vector of modes available for the current game. From 78ad5b45bd478a4aae1a4cd380f322f1af23f81f Mon Sep 17 00:00:00 2001 From: Jet Date: Thu, 1 Aug 2024 14:54:59 +0900 Subject: [PATCH 02/39] stash, gotta get back to work --- src/ale_interface.hpp | 8 +------- src/emucore/Settings.cxx | 1 - 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/ale_interface.hpp b/src/ale_interface.hpp index 832adf592..184d1754a 100644 --- a/src/ale_interface.hpp +++ b/src/ale_interface.hpp @@ -86,13 +86,7 @@ class ALEInterface { // user's responsibility to check if the game has ended and reset // when necessary - this method will keep pressing buttons on the // game over screen. - reward_t act(Action action); - - // Applies a continuous action to the game and returns the reward. It is the - // user's responsibility to check if the game has ended and reset - // when necessary - this method will keep pressing buttons on the - // game over screen. - reward_t actContinuous(float r, float theta, float fire); + reward_t act(float r, float theta, float fire); // Indicates if the game has ended. bool game_over(bool with_truncation = true) const; diff --git a/src/emucore/Settings.cxx b/src/emucore/Settings.cxx index bb11a8ca0..8e9193aae 100644 --- a/src/emucore/Settings.cxx +++ b/src/emucore/Settings.cxx @@ -415,7 +415,6 @@ void Settings::setDefaultSettings() { boolSettings.insert(std::pair("send_rgb", false)); intSettings.insert(std::pair("frame_skip", 1)); floatSettings.insert(std::pair("repeat_action_probability", 0.25)); - floatSettings.insert(std::pair("continuous_action_threshold", 0.5)); stringSettings.insert(std::pair("rom_file", "")); // Whether to truncate an episode on loss of life. boolSettings.insert(std::pair("truncate_on_loss_of_life", false)); From 5ae27d9b933799d5106222bdde6c636ff9bde516 Mon Sep 17 00:00:00 2001 From: Jet Date: Sat, 3 Aug 2024 17:12:24 +0900 Subject: [PATCH 03/39] remove discrete implementation and use only continuous --- src/environment/ale_state.cpp | 335 ++-------------------------------- src/environment/ale_state.hpp | 28 +-- 2 files changed, 23 insertions(+), 340 deletions(-) diff --git a/src/environment/ale_state.cpp b/src/environment/ale_state.cpp index 9c7ae8393..73d364682 100644 --- a/src/environment/ale_state.cpp +++ b/src/environment/ale_state.cpp @@ -189,106 +189,7 @@ void ALEState::updatePaddlePositions(Event* event, int delta_left, setPaddles(event, m_left_paddle, m_right_paddle); } -void ALEState::applyActionPaddles(Event* event, int player_a_action, - int player_b_action) { - // Reset keys - resetKeys(event); - - // First compute whether we should increase or decrease the paddle position - // (for both left and right players) - int delta_left; - int delta_right; - - switch (player_a_action) { - case PLAYER_A_RIGHT: - case PLAYER_A_RIGHTFIRE: - case PLAYER_A_UPRIGHT: - case PLAYER_A_DOWNRIGHT: - case PLAYER_A_UPRIGHTFIRE: - case PLAYER_A_DOWNRIGHTFIRE: - delta_left = -PADDLE_DELTA; - break; - - case PLAYER_A_LEFT: - case PLAYER_A_LEFTFIRE: - case PLAYER_A_UPLEFT: - case PLAYER_A_DOWNLEFT: - case PLAYER_A_UPLEFTFIRE: - case PLAYER_A_DOWNLEFTFIRE: - delta_left = PADDLE_DELTA; - break; - default: - delta_left = 0; - break; - } - - switch (player_b_action) { - case PLAYER_B_RIGHT: - case PLAYER_B_RIGHTFIRE: - case PLAYER_B_UPRIGHT: - case PLAYER_B_DOWNRIGHT: - case PLAYER_B_UPRIGHTFIRE: - case PLAYER_B_DOWNRIGHTFIRE: - delta_right = -PADDLE_DELTA; - break; - - case PLAYER_B_LEFT: - case PLAYER_B_LEFTFIRE: - case PLAYER_B_UPLEFT: - case PLAYER_B_DOWNLEFT: - case PLAYER_B_UPLEFTFIRE: - case PLAYER_B_DOWNLEFTFIRE: - delta_right = PADDLE_DELTA; - break; - default: - delta_right = 0; - break; - } - - // Now update the paddle positions - updatePaddlePositions(event, delta_left, delta_right); - - // Handle reset - if (player_a_action == RESET || player_b_action == RESET) - event->set(Event::ConsoleReset, 1); - - // Now add the fire event - switch (player_a_action) { - case PLAYER_A_FIRE: - case PLAYER_A_UPFIRE: - case PLAYER_A_RIGHTFIRE: - case PLAYER_A_LEFTFIRE: - case PLAYER_A_DOWNFIRE: - case PLAYER_A_UPRIGHTFIRE: - case PLAYER_A_UPLEFTFIRE: - case PLAYER_A_DOWNRIGHTFIRE: - case PLAYER_A_DOWNLEFTFIRE: - event->set(Event::PaddleZeroFire, 1); - break; - default: - // Nothing - break; - } - - switch (player_b_action) { - case PLAYER_B_FIRE: - case PLAYER_B_UPFIRE: - case PLAYER_B_RIGHTFIRE: - case PLAYER_B_LEFTFIRE: - case PLAYER_B_DOWNFIRE: - case PLAYER_B_UPRIGHTFIRE: - case PLAYER_B_UPLEFTFIRE: - case PLAYER_B_DOWNRIGHTFIRE: - case PLAYER_B_DOWNLEFTFIRE: - event->set(Event::PaddleOneFire, 1); - break; - default: - // Nothing - break; - } -} - -void ALEState::applyActionPaddlesContinuous( +void ALEState::applyActionPaddles( Event* event, float player_a_r, float player_a_theta, float player_a_fire, float player_b_r, float player_b_theta, float player_b_fire, @@ -296,31 +197,17 @@ void ALEState::applyActionPaddlesContinuous( // Reset keys resetKeys(event); - // Convert polar coordinates to x/y position. - float a_x = player_a_r * cos(player_a_theta); - float a_y = player_a_r * sin(player_a_theta); - float b_x = player_b_r * cos(player_b_theta); - float b_y = player_b_r * sin(player_b_theta); - - // First compute whether we should increase or decrease the paddle position - // (for both left and right players) - int delta_a = 0; - if (a_x > continuous_action_threshold) { // Right action. - delta_a = -PADDLE_DELTA; - } else if (a_x < -continuous_action_threshold) { // Left action. - delta_a = PADDLE_DELTA; - } - int delta_b = 0; - if (b_x > continuous_action_threshold) { // Right action. - delta_b = -PADDLE_DELTA; - } else if (b_x < -continuous_action_threshold) { // Left action. - delta_b = PADDLE_DELTA; - } - - // Now update the paddle positions - updatePaddlePositions(event, delta_a, delta_b); + // For paddles, we actually only have one continuous action per player + // This implementation mirror's PSC's original implementation of + // continuous action space, so that paddles uses the same interface as joysticks. + updatePaddlePositions( + event, + int(player_a_r * cos(player_a_theta) * PADDLE_DELTA), + int(player_b_r * cos(player_b_theta) * PADDLE_DELTA) + ); // Now add the fire event + // Don't have to call when 0 since `reset_keys` is automatically called. if (player_a_fire > continuous_action_threshold) { event->set(Event::PaddleZeroFire, 1); } @@ -343,205 +230,7 @@ void ALEState::setDifficultySwitches(Event* event, unsigned int value) { event->set(Event::ConsoleRightDifficultyB, !((value & 2) >> 1)); } -void ALEState::setActionJoysticks(Event* event, int player_a_action, - int player_b_action) { - // Reset keys - resetKeys(event); - - switch (player_a_action) { - case PLAYER_A_NOOP: - break; - - case PLAYER_A_FIRE: - event->set(Event::JoystickZeroFire, 1); - break; - - case PLAYER_A_UP: - event->set(Event::JoystickZeroUp, 1); - break; - - case PLAYER_A_RIGHT: - event->set(Event::JoystickZeroRight, 1); - break; - - case PLAYER_A_LEFT: - event->set(Event::JoystickZeroLeft, 1); - break; - - case PLAYER_A_DOWN: - event->set(Event::JoystickZeroDown, 1); - break; - - case PLAYER_A_UPRIGHT: - event->set(Event::JoystickZeroUp, 1); - event->set(Event::JoystickZeroRight, 1); - break; - - case PLAYER_A_UPLEFT: - event->set(Event::JoystickZeroUp, 1); - event->set(Event::JoystickZeroLeft, 1); - break; - - case PLAYER_A_DOWNRIGHT: - event->set(Event::JoystickZeroDown, 1); - event->set(Event::JoystickZeroRight, 1); - break; - - case PLAYER_A_DOWNLEFT: - event->set(Event::JoystickZeroDown, 1); - event->set(Event::JoystickZeroLeft, 1); - break; - - case PLAYER_A_UPFIRE: - event->set(Event::JoystickZeroUp, 1); - event->set(Event::JoystickZeroFire, 1); - break; - - case PLAYER_A_RIGHTFIRE: - event->set(Event::JoystickZeroRight, 1); - event->set(Event::JoystickZeroFire, 1); - break; - - case PLAYER_A_LEFTFIRE: - event->set(Event::JoystickZeroLeft, 1); - event->set(Event::JoystickZeroFire, 1); - break; - - case PLAYER_A_DOWNFIRE: - event->set(Event::JoystickZeroDown, 1); - event->set(Event::JoystickZeroFire, 1); - break; - - case PLAYER_A_UPRIGHTFIRE: - event->set(Event::JoystickZeroUp, 1); - event->set(Event::JoystickZeroRight, 1); - event->set(Event::JoystickZeroFire, 1); - break; - - case PLAYER_A_UPLEFTFIRE: - event->set(Event::JoystickZeroUp, 1); - event->set(Event::JoystickZeroLeft, 1); - event->set(Event::JoystickZeroFire, 1); - break; - - case PLAYER_A_DOWNRIGHTFIRE: - event->set(Event::JoystickZeroDown, 1); - event->set(Event::JoystickZeroRight, 1); - event->set(Event::JoystickZeroFire, 1); - break; - - case PLAYER_A_DOWNLEFTFIRE: - event->set(Event::JoystickZeroDown, 1); - event->set(Event::JoystickZeroLeft, 1); - event->set(Event::JoystickZeroFire, 1); - break; - case RESET: - event->set(Event::ConsoleReset, 1); - break; - default: - Logger::Error << "Invalid Player A Action: " << player_a_action << "\n"; - std::exit(-1); - } - - switch (player_b_action) { - case PLAYER_B_NOOP: - break; - - case PLAYER_B_FIRE: - event->set(Event::JoystickOneFire, 1); - break; - - case PLAYER_B_UP: - event->set(Event::JoystickOneUp, 1); - break; - - case PLAYER_B_RIGHT: - event->set(Event::JoystickOneRight, 1); - break; - - case PLAYER_B_LEFT: - event->set(Event::JoystickOneLeft, 1); - break; - - case PLAYER_B_DOWN: - event->set(Event::JoystickOneDown, 1); - break; - - case PLAYER_B_UPRIGHT: - event->set(Event::JoystickOneUp, 1); - event->set(Event::JoystickOneRight, 1); - break; - - case PLAYER_B_UPLEFT: - event->set(Event::JoystickOneUp, 1); - event->set(Event::JoystickOneLeft, 1); - break; - - case PLAYER_B_DOWNRIGHT: - event->set(Event::JoystickOneDown, 1); - event->set(Event::JoystickOneRight, 1); - break; - - case PLAYER_B_DOWNLEFT: - event->set(Event::JoystickOneDown, 1); - event->set(Event::JoystickOneLeft, 1); - break; - - case PLAYER_B_UPFIRE: - event->set(Event::JoystickOneUp, 1); - event->set(Event::JoystickOneFire, 1); - break; - - case PLAYER_B_RIGHTFIRE: - event->set(Event::JoystickOneRight, 1); - event->set(Event::JoystickOneFire, 1); - break; - - case PLAYER_B_LEFTFIRE: - event->set(Event::JoystickOneLeft, 1); - event->set(Event::JoystickOneFire, 1); - break; - - case PLAYER_B_DOWNFIRE: - event->set(Event::JoystickOneDown, 1); - event->set(Event::JoystickOneFire, 1); - break; - - case PLAYER_B_UPRIGHTFIRE: - event->set(Event::JoystickOneUp, 1); - event->set(Event::JoystickOneRight, 1); - event->set(Event::JoystickOneFire, 1); - break; - - case PLAYER_B_UPLEFTFIRE: - event->set(Event::JoystickOneUp, 1); - event->set(Event::JoystickOneLeft, 1); - event->set(Event::JoystickOneFire, 1); - break; - - case PLAYER_B_DOWNRIGHTFIRE: - event->set(Event::JoystickOneDown, 1); - event->set(Event::JoystickOneRight, 1); - event->set(Event::JoystickOneFire, 1); - break; - - case PLAYER_B_DOWNLEFTFIRE: - event->set(Event::JoystickOneDown, 1); - event->set(Event::JoystickOneLeft, 1); - event->set(Event::JoystickOneFire, 1); - break; - case RESET: - event->set(Event::ConsoleReset, 1); - Logger::Info << "Sending Reset...\n"; - break; - default: - Logger::Error << "Invalid Player B Action: " << player_b_action << "\n"; - std::exit(-1); - } -} - - -void ALEState::setActionJoysticksContinuous( +void ALEState::setActionJoysticks( Event* event, float player_a_r, float player_a_theta, float player_a_fire, float player_b_r, float player_b_theta, float player_b_fire, @@ -556,6 +245,8 @@ void ALEState::setActionJoysticksContinuous( float b_y = player_b_r * sin(player_b_theta); // Go through all possible events and add them if joystick position is there. + // Original Atari 2600 doesn't have continous actions for joystick actions + // So we need to quantize here if (a_x < -continuous_action_threshold) { event->set(Event::JoystickZeroLeft, 1); } diff --git a/src/environment/ale_state.hpp b/src/environment/ale_state.hpp index 6aad84683..2a9222629 100644 --- a/src/environment/ale_state.hpp +++ b/src/environment/ale_state.hpp @@ -60,27 +60,19 @@ class ALEState { //Apply the special select action void pressSelect(stella::Event* event_obj); - /** Applies paddle actions. This actually modifies the game state by updating the paddle - * resistances. */ - void applyActionPaddles(stella::Event* event_obj, int player_a_action, - int player_b_action); /** Applies paddle continuous actions. This actually modifies the game state * by updating the paddle resistances. */ - void applyActionPaddlesContinuous( - stella::Event* event_obj, - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, - float continuous_action_treshold = 0.5); + void applyActionPaddles( + stella::Event* event_obj, + float player_a_r, float player_a_theta, float player_a_fire, + float player_b_r, float player_b_theta, float player_b_fire, + ); /** Sets the joystick events. No effect until the emulator is run forward. */ - void setActionJoysticks(stella::Event* event_obj, int player_a_action, - int player_b_action); - /** Sets the joystick events for continuous actions. No effect until the - * emulator is run forward. */ - void setActionJoysticksContinuous( - stella::Event* event_obj, - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, - float continuous_action_threshold = 0.5); + void setActionJoysticks( + stella::Event* event_obj, + float player_a_r, float player_a_theta, float player_a_fire, + float player_b_r, float player_b_theta, float player_b_fire, + ); void incrementFrame(int steps = 1); From b2af342ab99ff1b6757eca158ea96a73abbc53f7 Mon Sep 17 00:00:00 2001 From: Jet Date: Sat, 3 Aug 2024 17:28:44 +0900 Subject: [PATCH 04/39] split the thresholds --- src/environment/ale_state.cpp | 33 +++++++++++++++++++-------------- src/environment/ale_state.hpp | 6 ++++++ 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/src/environment/ale_state.cpp b/src/environment/ale_state.cpp index 73d364682..cb801f14a 100644 --- a/src/environment/ale_state.cpp +++ b/src/environment/ale_state.cpp @@ -161,6 +161,11 @@ void ALEState::setPaddleLimits(int paddle_min_val, int paddle_max_val) { // paddle update and the positions will be clamped to the new min/max. } +void ALEState::setActionThresholds(float joystick_discrete_threshold, float fire_discrete_threshold) { + m_joystick_threshold = joystick_discrete_threshold; + m_paddle_max = fire_discrete_threshold; +} + /* ********************************************************************* * Updates the positions of the paddles, and sets an event for * updating the corresponding paddle's resistance @@ -193,7 +198,7 @@ void ALEState::applyActionPaddles( Event* event, float player_a_r, float player_a_theta, float player_a_fire, float player_b_r, float player_b_theta, float player_b_fire, - float continuous_action_threshold) { +) { // Reset keys resetKeys(event); @@ -208,10 +213,10 @@ void ALEState::applyActionPaddles( // Now add the fire event // Don't have to call when 0 since `reset_keys` is automatically called. - if (player_a_fire > continuous_action_threshold) { + if (player_a_fire > m_fire_threshold) { event->set(Event::PaddleZeroFire, 1); } - if (player_b_fire > continuous_action_threshold) { + if (player_b_fire > m_fire_threshold) { event->set(Event::PaddleOneFire, 1); } } @@ -234,7 +239,7 @@ void ALEState::setActionJoysticks( Event* event, float player_a_r, float player_a_theta, float player_a_fire, float player_b_r, float player_b_theta, float player_b_fire, - float continuous_action_threshold) { +) { // Reset keys resetKeys(event); @@ -247,34 +252,34 @@ void ALEState::setActionJoysticks( // Go through all possible events and add them if joystick position is there. // Original Atari 2600 doesn't have continous actions for joystick actions // So we need to quantize here - if (a_x < -continuous_action_threshold) { + if (a_x < -m_joystick_threshold) { event->set(Event::JoystickZeroLeft, 1); } - if (a_x > continuous_action_threshold) { + if (a_x > m_joystick_threshold) { event->set(Event::JoystickZeroRight, 1); } - if (a_y < -continuous_action_threshold) { + if (a_y < -m_joystick_threshold) { event->set(Event::JoystickZeroDown, 1); } - if (a_y > continuous_action_threshold) { + if (a_y > m_joystick_threshold) { event->set(Event::JoystickZeroUp, 1); } - if (player_a_fire > continuous_action_threshold) { + if (player_a_fire > m_fire_threshold) { event->set(Event::JoystickZeroFire, 1); } - if (b_x < -continuous_action_threshold) { + if (b_x < -m_joystick_threshold) { event->set(Event::JoystickOneLeft, 1); } - if (b_x > continuous_action_threshold) { + if (b_x > m_joystick_threshold) { event->set(Event::JoystickOneRight, 1); } - if (b_y < -continuous_action_threshold) { + if (b_y < -m_joystick_threshold) { event->set(Event::JoystickOneDown, 1); } - if (b_y > continuous_action_threshold) { + if (b_y > m_joystick_threshold) { event->set(Event::JoystickOneUp, 1); } - if (player_b_fire > continuous_action_threshold) { + if (player_b_fire > m_fire_threshold) { event->set(Event::JoystickOneFire, 1); } } diff --git a/src/environment/ale_state.hpp b/src/environment/ale_state.hpp index 2a9222629..f5acc0347 100644 --- a/src/environment/ale_state.hpp +++ b/src/environment/ale_state.hpp @@ -117,6 +117,9 @@ class ALEState { /** Reset key presses */ void resetKeys(stella::Event* event_obj); + /** Set the clipping thresholds for the joystick and fire buttons.*/ + void setActionThresholds(float joystick_discrete_threshold, float fire_discrete_threshold); + /** Sets the paddle to a given position */ void setPaddles(stella::Event* event_obj, int left, int right); @@ -133,6 +136,9 @@ class ALEState { void setDifficultySwitches(stella::Event* event_obj, unsigned int value); private: + float m_joystick_threshold; // Threshold for continuous to discrete clip for joystick movements + float m_fire_threshold; // Threshold for continuous to discrete clip for the fire buttons (joystick and paddle mode) + int m_left_paddle; // Current value for the left-paddle int m_right_paddle; // Current value for the right-paddle From 229f1ed7598bc2793ac50cfe54a7564645dd04cb Mon Sep 17 00:00:00 2001 From: Jet Date: Sat, 3 Aug 2024 17:29:19 +0900 Subject: [PATCH 05/39] add thresholds --- src/emucore/Settings.cxx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/emucore/Settings.cxx b/src/emucore/Settings.cxx index 8e9193aae..bd238c8ca 100644 --- a/src/emucore/Settings.cxx +++ b/src/emucore/Settings.cxx @@ -415,6 +415,8 @@ void Settings::setDefaultSettings() { boolSettings.insert(std::pair("send_rgb", false)); intSettings.insert(std::pair("frame_skip", 1)); floatSettings.insert(std::pair("repeat_action_probability", 0.25)); + floatSettings.insert(std::pair("joystick_discrete_threshold", 0.5)); + floatSettings.insert(std::pair("fire_discrete_threshold", 0.0)); stringSettings.insert(std::pair("rom_file", "")); // Whether to truncate an episode on loss of life. boolSettings.insert(std::pair("truncate_on_loss_of_life", false)); From 0355c914290d7cd580ac642da342ee7b1bd61f11 Mon Sep 17 00:00:00 2001 From: Jet Date: Sat, 3 Aug 2024 22:11:33 +0900 Subject: [PATCH 06/39] use true-to-game actions --- src/environment/ale_state.cpp | 256 +++++++++++++++++++------ src/environment/ale_state.hpp | 13 +- src/environment/stella_environment.cpp | 2 - 3 files changed, 202 insertions(+), 69 deletions(-) diff --git a/src/environment/ale_state.cpp b/src/environment/ale_state.cpp index cb801f14a..ec7291ba3 100644 --- a/src/environment/ale_state.cpp +++ b/src/environment/ale_state.cpp @@ -196,29 +196,20 @@ void ALEState::updatePaddlePositions(Event* event, int delta_left, void ALEState::applyActionPaddles( Event* event, - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, + float player_a_paddle, bool player_a_fire, + float player_b_paddle, bool player_b_fire, ) { - // Reset keys + // Reset keys (this zeros out the paddle) resetKeys(event); - // For paddles, we actually only have one continuous action per player - // This implementation mirror's PSC's original implementation of - // continuous action space, so that paddles uses the same interface as joysticks. + // send paddle position and fire updatePaddlePositions( event, - int(player_a_r * cos(player_a_theta) * PADDLE_DELTA), - int(player_b_r * cos(player_b_theta) * PADDLE_DELTA) + int(player_a_paddle * PADDLE_DELTA), + int(player_b_paddle * PADDLE_DELTA) ); - - // Now add the fire event - // Don't have to call when 0 since `reset_keys` is automatically called. - if (player_a_fire > m_fire_threshold) { - event->set(Event::PaddleZeroFire, 1); - } - if (player_b_fire > m_fire_threshold) { - event->set(Event::PaddleOneFire, 1); - } + event->set(Event::PaddleZeroFire, int(player_a_fire)); + event->set(Event::PaddleOneFire, int(player_b_fire)); } void ALEState::pressSelect(Event* event) { @@ -236,51 +227,202 @@ void ALEState::setDifficultySwitches(Event* event, unsigned int value) { } void ALEState::setActionJoysticks( - Event* event, - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, + Event* event, + int player_a_action, + int player_b_action, ) { // Reset keys resetKeys(event); - // Convert polar coordinates to x/y position. - float a_x = player_a_r * cos(player_a_theta); - float a_y = player_a_r * sin(player_a_theta); - float b_x = player_b_r * cos(player_b_theta); - float b_y = player_b_r * sin(player_b_theta); - - // Go through all possible events and add them if joystick position is there. - // Original Atari 2600 doesn't have continous actions for joystick actions - // So we need to quantize here - if (a_x < -m_joystick_threshold) { - event->set(Event::JoystickZeroLeft, 1); - } - if (a_x > m_joystick_threshold) { - event->set(Event::JoystickZeroRight, 1); - } - if (a_y < -m_joystick_threshold) { - event->set(Event::JoystickZeroDown, 1); + switch (player_a_action) { + case PLAYER_A_NOOP: + break; + + case PLAYER_A_FIRE: + event->set(Event::JoystickZeroFire, 1); + break; + + case PLAYER_A_UP: + event->set(Event::JoystickZeroUp, 1); + break; + + case PLAYER_A_RIGHT: + event->set(Event::JoystickZeroRight, 1); + break; + + case PLAYER_A_LEFT: + event->set(Event::JoystickZeroLeft, 1); + break; + + case PLAYER_A_DOWN: + event->set(Event::JoystickZeroDown, 1); + break; + + case PLAYER_A_UPRIGHT: + event->set(Event::JoystickZeroUp, 1); + event->set(Event::JoystickZeroRight, 1); + break; + + case PLAYER_A_UPLEFT: + event->set(Event::JoystickZeroUp, 1); + event->set(Event::JoystickZeroLeft, 1); + break; + + case PLAYER_A_DOWNRIGHT: + event->set(Event::JoystickZeroDown, 1); + event->set(Event::JoystickZeroRight, 1); + break; + + case PLAYER_A_DOWNLEFT: + event->set(Event::JoystickZeroDown, 1); + event->set(Event::JoystickZeroLeft, 1); + break; + + case PLAYER_A_UPFIRE: + event->set(Event::JoystickZeroUp, 1); + event->set(Event::JoystickZeroFire, 1); + break; + + case PLAYER_A_RIGHTFIRE: + event->set(Event::JoystickZeroRight, 1); + event->set(Event::JoystickZeroFire, 1); + break; + + case PLAYER_A_LEFTFIRE: + event->set(Event::JoystickZeroLeft, 1); + event->set(Event::JoystickZeroFire, 1); + break; + + case PLAYER_A_DOWNFIRE: + event->set(Event::JoystickZeroDown, 1); + event->set(Event::JoystickZeroFire, 1); + break; + + case PLAYER_A_UPRIGHTFIRE: + event->set(Event::JoystickZeroUp, 1); + event->set(Event::JoystickZeroRight, 1); + event->set(Event::JoystickZeroFire, 1); + break; + + case PLAYER_A_UPLEFTFIRE: + event->set(Event::JoystickZeroUp, 1); + event->set(Event::JoystickZeroLeft, 1); + event->set(Event::JoystickZeroFire, 1); + break; + + case PLAYER_A_DOWNRIGHTFIRE: + event->set(Event::JoystickZeroDown, 1); + event->set(Event::JoystickZeroRight, 1); + event->set(Event::JoystickZeroFire, 1); + break; + + case PLAYER_A_DOWNLEFTFIRE: + event->set(Event::JoystickZeroDown, 1); + event->set(Event::JoystickZeroLeft, 1); + event->set(Event::JoystickZeroFire, 1); + break; + case RESET: + event->set(Event::ConsoleReset, 1); + break; + default: + Logger::Error << "Invalid Player A Action: " << player_a_action << "\n"; + std::exit(-1); } - if (a_y > m_joystick_threshold) { - event->set(Event::JoystickZeroUp, 1); - } - if (player_a_fire > m_fire_threshold) { - event->set(Event::JoystickZeroFire, 1); - } - if (b_x < -m_joystick_threshold) { - event->set(Event::JoystickOneLeft, 1); - } - if (b_x > m_joystick_threshold) { - event->set(Event::JoystickOneRight, 1); - } - if (b_y < -m_joystick_threshold) { - event->set(Event::JoystickOneDown, 1); - } - if (b_y > m_joystick_threshold) { - event->set(Event::JoystickOneUp, 1); - } - if (player_b_fire > m_fire_threshold) { - event->set(Event::JoystickOneFire, 1); + + switch (player_b_action) { + case PLAYER_B_NOOP: + break; + + case PLAYER_B_FIRE: + event->set(Event::JoystickOneFire, 1); + break; + + case PLAYER_B_UP: + event->set(Event::JoystickOneUp, 1); + break; + + case PLAYER_B_RIGHT: + event->set(Event::JoystickOneRight, 1); + break; + + case PLAYER_B_LEFT: + event->set(Event::JoystickOneLeft, 1); + break; + + case PLAYER_B_DOWN: + event->set(Event::JoystickOneDown, 1); + break; + + case PLAYER_B_UPRIGHT: + event->set(Event::JoystickOneUp, 1); + event->set(Event::JoystickOneRight, 1); + break; + + case PLAYER_B_UPLEFT: + event->set(Event::JoystickOneUp, 1); + event->set(Event::JoystickOneLeft, 1); + break; + + case PLAYER_B_DOWNRIGHT: + event->set(Event::JoystickOneDown, 1); + event->set(Event::JoystickOneRight, 1); + break; + + case PLAYER_B_DOWNLEFT: + event->set(Event::JoystickOneDown, 1); + event->set(Event::JoystickOneLeft, 1); + break; + + case PLAYER_B_UPFIRE: + event->set(Event::JoystickOneUp, 1); + event->set(Event::JoystickOneFire, 1); + break; + + case PLAYER_B_RIGHTFIRE: + event->set(Event::JoystickOneRight, 1); + event->set(Event::JoystickOneFire, 1); + break; + + case PLAYER_B_LEFTFIRE: + event->set(Event::JoystickOneLeft, 1); + event->set(Event::JoystickOneFire, 1); + break; + + case PLAYER_B_DOWNFIRE: + event->set(Event::JoystickOneDown, 1); + event->set(Event::JoystickOneFire, 1); + break; + + case PLAYER_B_UPRIGHTFIRE: + event->set(Event::JoystickOneUp, 1); + event->set(Event::JoystickOneRight, 1); + event->set(Event::JoystickOneFire, 1); + break; + + case PLAYER_B_UPLEFTFIRE: + event->set(Event::JoystickOneUp, 1); + event->set(Event::JoystickOneLeft, 1); + event->set(Event::JoystickOneFire, 1); + break; + + case PLAYER_B_DOWNRIGHTFIRE: + event->set(Event::JoystickOneDown, 1); + event->set(Event::JoystickOneRight, 1); + event->set(Event::JoystickOneFire, 1); + break; + + case PLAYER_B_DOWNLEFTFIRE: + event->set(Event::JoystickOneDown, 1); + event->set(Event::JoystickOneLeft, 1); + event->set(Event::JoystickOneFire, 1); + break; + case RESET: + event->set(Event::ConsoleReset, 1); + Logger::Info << "Sending Reset...\n"; + break; + default: + Logger::Error << "Invalid Player B Action: " << player_b_action << "\n"; + std::exit(-1); } } diff --git a/src/environment/ale_state.hpp b/src/environment/ale_state.hpp index f5acc0347..d6a29453e 100644 --- a/src/environment/ale_state.hpp +++ b/src/environment/ale_state.hpp @@ -64,14 +64,13 @@ class ALEState { * by updating the paddle resistances. */ void applyActionPaddles( stella::Event* event_obj, - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, + float player_a_paddle, bool player_a_fire, + float player_b_paddle, bool player_b_fire, ); /** Sets the joystick events. No effect until the emulator is run forward. */ void setActionJoysticks( stella::Event* event_obj, - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, + int player_a_action, int player_b_action, ); void incrementFrame(int steps = 1); @@ -117,9 +116,6 @@ class ALEState { /** Reset key presses */ void resetKeys(stella::Event* event_obj); - /** Set the clipping thresholds for the joystick and fire buttons.*/ - void setActionThresholds(float joystick_discrete_threshold, float fire_discrete_threshold); - /** Sets the paddle to a given position */ void setPaddles(stella::Event* event_obj, int left, int right); @@ -136,9 +132,6 @@ class ALEState { void setDifficultySwitches(stella::Event* event_obj, unsigned int value); private: - float m_joystick_threshold; // Threshold for continuous to discrete clip for joystick movements - float m_fire_threshold; // Threshold for continuous to discrete clip for the fire buttons (joystick and paddle mode) - int m_left_paddle; // Current value for the left-paddle int m_right_paddle; // Current value for the right-paddle diff --git a/src/environment/stella_environment.cpp b/src/environment/stella_environment.cpp index 5e57c8bf9..b525cfc2b 100644 --- a/src/environment/stella_environment.cpp +++ b/src/environment/stella_environment.cpp @@ -76,8 +76,6 @@ StellaEnvironment::StellaEnvironment(OSystem* osystem, RomSettings* settings) m_repeat_action_probability = m_osystem->settings().getFloat("repeat_action_probability"); - m_continuous_action_threshold = - m_osystem->settings().getFloat("continuous_action_threshold"); m_frame_skip = m_osystem->settings().getInt("frame_skip"); if (m_frame_skip < 1) { From 0bc1ea34b28aebd1db3a1e7f963c272898025b1c Mon Sep 17 00:00:00 2001 From: Jet Date: Sat, 3 Aug 2024 23:53:00 +0900 Subject: [PATCH 07/39] I think... I'm happy with this interface for now --- src/environment/ale_state.cpp | 154 +++++++++++++++++++++------------- src/environment/ale_state.hpp | 6 +- 2 files changed, 99 insertions(+), 61 deletions(-) diff --git a/src/environment/ale_state.cpp b/src/environment/ale_state.cpp index ec7291ba3..3b49a5187 100644 --- a/src/environment/ale_state.cpp +++ b/src/environment/ale_state.cpp @@ -161,11 +161,6 @@ void ALEState::setPaddleLimits(int paddle_min_val, int paddle_max_val) { // paddle update and the positions will be clamped to the new min/max. } -void ALEState::setActionThresholds(float joystick_discrete_threshold, float fire_discrete_threshold) { - m_joystick_threshold = joystick_discrete_threshold; - m_paddle_max = fire_discrete_threshold; -} - /* ********************************************************************* * Updates the positions of the paddles, and sets an event for * updating the corresponding paddle's resistance @@ -195,21 +190,102 @@ void ALEState::updatePaddlePositions(Event* event, int delta_left, } void ALEState::applyActionPaddles( - Event* event, - float player_a_paddle, bool player_a_fire, - float player_b_paddle, bool player_b_fire, + Event* event, + int player_a_action, float paddle_a_strength, + int player_b_action, float paddle_b_strength, ) { - // Reset keys (this zeros out the paddle) + // Reset keys resetKeys(event); - // send paddle position and fire - updatePaddlePositions( - event, - int(player_a_paddle * PADDLE_DELTA), - int(player_b_paddle * PADDLE_DELTA) - ); - event->set(Event::PaddleZeroFire, int(player_a_fire)); - event->set(Event::PaddleOneFire, int(player_b_fire)); + int delta_a = 0; + int delta_b = 0; + switch (player_a_action) { + case PLAYER_A_RIGHT: + case PLAYER_A_RIGHTFIRE: + case PLAYER_A_UPRIGHT: + case PLAYER_A_DOWNRIGHT: + case PLAYER_A_UPRIGHTFIRE: + case PLAYER_A_DOWNRIGHTFIRE: + delta_a = static_cast(-PADDLE_DELTA * fabs(paddle_a_strength)); + break; + + case PLAYER_A_LEFT: + case PLAYER_A_LEFTFIRE: + case PLAYER_A_UPLEFT: + case PLAYER_A_DOWNLEFT: + case PLAYER_A_UPLEFTFIRE: + case PLAYER_A_DOWNLEFTFIRE: + delta_a = static_cast(PADDLE_DELTA * fabs(paddle_a_strength)); + break; + default: + break; + } + + switch (player_b_action) { + case PLAYER_B_RIGHT: + case PLAYER_B_RIGHTFIRE: + case PLAYER_B_UPRIGHT: + case PLAYER_B_DOWNRIGHT: + case PLAYER_B_UPRIGHTFIRE: + case PLAYER_B_DOWNRIGHTFIRE: + delta_b = static_cast(-PADDLE_DELTA * fabs(paddle_b_strength)); + break; + + case PLAYER_B_LEFT: + case PLAYER_B_LEFTFIRE: + case PLAYER_B_UPLEFT: + case PLAYER_B_DOWNLEFT: + case PLAYER_B_UPLEFTFIRE: + case PLAYER_B_DOWNLEFTFIRE: + delta_b = static_cast(PADDLE_DELTA * fabs(paddle_b_strength)); + break; + default: + break; + } + + // Now update the paddle positions + updatePaddlePositions(event, delta_left, delta_right); + + // Handle reset + if (player_a_action == RESET || player_b_action == RESET) + event->set(Event::ConsoleReset, 1); + + // Now add the fire event + switch (player_a_action) { + case PLAYER_A_FIRE: + case PLAYER_A_UPFIRE: + case PLAYER_A_RIGHTFIRE: + case PLAYER_A_LEFTFIRE: + case PLAYER_A_DOWNFIRE: + case PLAYER_A_UPRIGHTFIRE: + case PLAYER_A_UPLEFTFIRE: + case PLAYER_A_DOWNRIGHTFIRE: + case PLAYER_A_DOWNLEFTFIRE: + event->set(Event::PaddleZeroFire, 1); + break; + default: + // Nothing + break; + } + + switch (player_b_action) { + case PLAYER_B_FIRE: + case PLAYER_B_UPFIRE: + case PLAYER_B_RIGHTFIRE: + case PLAYER_B_LEFTFIRE: + case PLAYER_B_DOWNFIRE: + case PLAYER_B_UPRIGHTFIRE: + case PLAYER_B_UPLEFTFIRE: + case PLAYER_B_DOWNRIGHTFIRE: + case PLAYER_B_DOWNLEFTFIRE: + event->set(Event::PaddleOneFire, 1); + break; + default: + // Nothing + break; + } +} + } void ALEState::pressSelect(Event* event) { @@ -226,96 +302,75 @@ void ALEState::setDifficultySwitches(Event* event, unsigned int value) { event->set(Event::ConsoleRightDifficultyB, !((value & 2) >> 1)); } -void ALEState::setActionJoysticks( - Event* event, - int player_a_action, - int player_b_action, -) { +void ALEState::setActionJoysticks(Event* event, int player_a_action, + int player_b_action) { // Reset keys resetKeys(event); - switch (player_a_action) { case PLAYER_A_NOOP: break; - case PLAYER_A_FIRE: event->set(Event::JoystickZeroFire, 1); break; - case PLAYER_A_UP: event->set(Event::JoystickZeroUp, 1); break; - case PLAYER_A_RIGHT: event->set(Event::JoystickZeroRight, 1); break; - case PLAYER_A_LEFT: event->set(Event::JoystickZeroLeft, 1); break; - case PLAYER_A_DOWN: event->set(Event::JoystickZeroDown, 1); break; - case PLAYER_A_UPRIGHT: event->set(Event::JoystickZeroUp, 1); event->set(Event::JoystickZeroRight, 1); break; - case PLAYER_A_UPLEFT: event->set(Event::JoystickZeroUp, 1); event->set(Event::JoystickZeroLeft, 1); break; - case PLAYER_A_DOWNRIGHT: event->set(Event::JoystickZeroDown, 1); event->set(Event::JoystickZeroRight, 1); break; - case PLAYER_A_DOWNLEFT: event->set(Event::JoystickZeroDown, 1); event->set(Event::JoystickZeroLeft, 1); break; - case PLAYER_A_UPFIRE: event->set(Event::JoystickZeroUp, 1); event->set(Event::JoystickZeroFire, 1); break; - case PLAYER_A_RIGHTFIRE: event->set(Event::JoystickZeroRight, 1); event->set(Event::JoystickZeroFire, 1); break; - case PLAYER_A_LEFTFIRE: event->set(Event::JoystickZeroLeft, 1); event->set(Event::JoystickZeroFire, 1); break; - case PLAYER_A_DOWNFIRE: event->set(Event::JoystickZeroDown, 1); event->set(Event::JoystickZeroFire, 1); break; - case PLAYER_A_UPRIGHTFIRE: event->set(Event::JoystickZeroUp, 1); event->set(Event::JoystickZeroRight, 1); event->set(Event::JoystickZeroFire, 1); break; - case PLAYER_A_UPLEFTFIRE: event->set(Event::JoystickZeroUp, 1); event->set(Event::JoystickZeroLeft, 1); event->set(Event::JoystickZeroFire, 1); break; - case PLAYER_A_DOWNRIGHTFIRE: event->set(Event::JoystickZeroDown, 1); event->set(Event::JoystickZeroRight, 1); event->set(Event::JoystickZeroFire, 1); break; - case PLAYER_A_DOWNLEFTFIRE: event->set(Event::JoystickZeroDown, 1); event->set(Event::JoystickZeroLeft, 1); @@ -323,94 +378,77 @@ void ALEState::setActionJoysticks( break; case RESET: event->set(Event::ConsoleReset, 1); + Logger::Info << "Sending Reset...\n"; break; default: Logger::Error << "Invalid Player A Action: " << player_a_action << "\n"; std::exit(-1); } - switch (player_b_action) { case PLAYER_B_NOOP: break; - case PLAYER_B_FIRE: event->set(Event::JoystickOneFire, 1); break; - case PLAYER_B_UP: event->set(Event::JoystickOneUp, 1); break; - case PLAYER_B_RIGHT: event->set(Event::JoystickOneRight, 1); break; - case PLAYER_B_LEFT: event->set(Event::JoystickOneLeft, 1); break; - case PLAYER_B_DOWN: event->set(Event::JoystickOneDown, 1); break; - case PLAYER_B_UPRIGHT: event->set(Event::JoystickOneUp, 1); event->set(Event::JoystickOneRight, 1); break; - case PLAYER_B_UPLEFT: event->set(Event::JoystickOneUp, 1); event->set(Event::JoystickOneLeft, 1); break; - case PLAYER_B_DOWNRIGHT: event->set(Event::JoystickOneDown, 1); event->set(Event::JoystickOneRight, 1); break; - case PLAYER_B_DOWNLEFT: event->set(Event::JoystickOneDown, 1); event->set(Event::JoystickOneLeft, 1); break; - case PLAYER_B_UPFIRE: event->set(Event::JoystickOneUp, 1); event->set(Event::JoystickOneFire, 1); break; - case PLAYER_B_RIGHTFIRE: event->set(Event::JoystickOneRight, 1); event->set(Event::JoystickOneFire, 1); break; - case PLAYER_B_LEFTFIRE: event->set(Event::JoystickOneLeft, 1); event->set(Event::JoystickOneFire, 1); break; - case PLAYER_B_DOWNFIRE: event->set(Event::JoystickOneDown, 1); event->set(Event::JoystickOneFire, 1); break; - case PLAYER_B_UPRIGHTFIRE: event->set(Event::JoystickOneUp, 1); event->set(Event::JoystickOneRight, 1); event->set(Event::JoystickOneFire, 1); break; - case PLAYER_B_UPLEFTFIRE: event->set(Event::JoystickOneUp, 1); event->set(Event::JoystickOneLeft, 1); event->set(Event::JoystickOneFire, 1); break; - case PLAYER_B_DOWNRIGHTFIRE: event->set(Event::JoystickOneDown, 1); event->set(Event::JoystickOneRight, 1); event->set(Event::JoystickOneFire, 1); break; - case PLAYER_B_DOWNLEFTFIRE: event->set(Event::JoystickOneDown, 1); event->set(Event::JoystickOneLeft, 1); diff --git a/src/environment/ale_state.hpp b/src/environment/ale_state.hpp index d6a29453e..385e70bd1 100644 --- a/src/environment/ale_state.hpp +++ b/src/environment/ale_state.hpp @@ -64,11 +64,11 @@ class ALEState { * by updating the paddle resistances. */ void applyActionPaddles( stella::Event* event_obj, - float player_a_paddle, bool player_a_fire, - float player_b_paddle, bool player_b_fire, + int player_a_action, float paddle_a_strength = 1.0, + int player_b_action, float paddle_b_strength = 1.0, ); /** Sets the joystick events. No effect until the emulator is run forward. */ - void setActionJoysticks( + void applyActionJoysticks( stella::Event* event_obj, int player_a_action, int player_b_action, ); From 998e3e367b0a07575ff38d73cfa6f320fcc2ef78 Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 00:04:36 +0900 Subject: [PATCH 08/39] amend stella env --- src/environment/ale_state.cpp | 12 +-- src/environment/ale_state.hpp | 15 ++- src/environment/stella_environment.cpp | 135 ++++--------------------- src/environment/stella_environment.hpp | 24 ++--- 4 files changed, 38 insertions(+), 148 deletions(-) diff --git a/src/environment/ale_state.cpp b/src/environment/ale_state.cpp index 3b49a5187..b050bb07f 100644 --- a/src/environment/ale_state.cpp +++ b/src/environment/ale_state.cpp @@ -189,11 +189,9 @@ void ALEState::updatePaddlePositions(Event* event, int delta_left, setPaddles(event, m_left_paddle, m_right_paddle); } -void ALEState::applyActionPaddles( - Event* event, - int player_a_action, float paddle_a_strength, - int player_b_action, float paddle_b_strength, -) { +void ALEState::applyActionPaddles(Event* event, + int player_a_action, float paddle_a_strength, + int player_b_action, float paddle_b_strength) { // Reset keys resetKeys(event); @@ -302,8 +300,8 @@ void ALEState::setDifficultySwitches(Event* event, unsigned int value) { event->set(Event::ConsoleRightDifficultyB, !((value & 2) >> 1)); } -void ALEState::setActionJoysticks(Event* event, int player_a_action, - int player_b_action) { +void ALEState::setActionJoysticks(Event* event, + int player_a_action, int player_b_action) { // Reset keys resetKeys(event); switch (player_a_action) { diff --git a/src/environment/ale_state.hpp b/src/environment/ale_state.hpp index 385e70bd1..4e4d48e57 100644 --- a/src/environment/ale_state.hpp +++ b/src/environment/ale_state.hpp @@ -62,16 +62,13 @@ class ALEState { /** Applies paddle continuous actions. This actually modifies the game state * by updating the paddle resistances. */ - void applyActionPaddles( - stella::Event* event_obj, - int player_a_action, float paddle_a_strength = 1.0, - int player_b_action, float paddle_b_strength = 1.0, - ); + void applyActionPaddles(stella::Event* event_obj, + int player_a_action, float paddle_a_strength, + int player_b_action, float paddle_b_strength); + /** Sets the joystick events. No effect until the emulator is run forward. */ - void applyActionJoysticks( - stella::Event* event_obj, - int player_a_action, int player_b_action, - ); + void applyActionJoysticks(stella::Event* event_obj, + int player_a_action, int player_b_action); void incrementFrame(int steps = 1); diff --git a/src/environment/stella_environment.cpp b/src/environment/stella_environment.cpp index b525cfc2b..90acc645b 100644 --- a/src/environment/stella_environment.cpp +++ b/src/environment/stella_environment.cpp @@ -152,8 +152,8 @@ void StellaEnvironment::noopIllegalActions(Action& player_a_action, player_b_action = (Action)PLAYER_B_NOOP; } -reward_t StellaEnvironment::act(Action player_a_action, - Action player_b_action) { +reward_t StellaEnvironment::act(Action player_a_action, float paddle_a_strength, + Action player_b_action, float paddle_b_strength) { // Total reward received as we repeat the action reward_t sum_rewards = 0; @@ -165,9 +165,11 @@ reward_t StellaEnvironment::act(Action player_a_action, // Stochastically drop actions, according to m_repeat_action_probability if (rng.nextDouble() >= m_repeat_action_probability) m_player_a_action = player_a_action; + m_paddle_a_strength = paddle_a_strength; // @todo Possibly optimize by avoiding call to rand() when player B is "off" ? if (rng.nextDouble() >= m_repeat_action_probability) m_player_b_action = player_b_action; + m_paddle_b_strength = paddle_b_strength; // If so desired, request one frame's worth of sound (this does nothing if recording // is not enabled) @@ -181,55 +183,13 @@ reward_t StellaEnvironment::act(Action player_a_action, m_screen_exporter->saveNext(m_screen); // Use the stored actions, which may or may not have changed this frame - sum_rewards += oneStepAct(m_player_a_action, m_player_b_action); + sum_rewards += oneStepAct(m_player_a_action, m_player_a_strength, + m_player_b_action, m_player_b_strength); } return std::clamp(sum_rewards, m_reward_min, m_reward_max); } -reward_t StellaEnvironment::actContinuous( - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire) { - // Total reward received as we repeat the action - reward_t sum_rewards = 0; - - Random& rng = getEnvironmentRNG(); - - // Apply the same action for a given number of times... note that act() will refuse to emulate - // past the terminal state - for (size_t i = 0; i < m_frame_skip; i++) { - // Stochastically drop actions, according to m_repeat_action_probability - if (rng.nextDouble() >= m_repeat_action_probability) { - m_player_a_r = player_a_r; - m_player_a_theta = player_a_theta; - m_player_a_fire = player_a_fire; - } - // @todo Possibly optimize by avoiding call to rand() when player B is "off" ? - if (rng.nextDouble() >= m_repeat_action_probability) { - m_player_b_r = player_b_r; - m_player_b_theta = player_b_theta; - m_player_b_fire = player_b_fire; - } - - // If so desired, request one frame's worth of sound (this does nothing if recording - // is not enabled) - m_osystem->sound().recordNextFrame(); - - // Render screen if we're displaying it - m_osystem->screen().render(); - - // Similarly record screen as needed - if (m_screen_exporter.get() != NULL) - m_screen_exporter->saveNext(m_screen); - - // Use the stored actions, which may or may not have changed this frame - sum_rewards += oneStepActContinuous(m_player_a_r, m_player_a_theta, m_player_a_fire, - m_player_b_r, m_player_b_theta, m_player_b_fire); - } - - return sum_rewards; -} - /** This functions emulates a push on the reset button of the console */ void StellaEnvironment::softReset() { emulate(RESET, PLAYER_B_NOOP, m_num_reset_steps); @@ -241,8 +201,8 @@ void StellaEnvironment::softReset() { /** Applies the given actions (e.g. updating paddle positions when the paddle is used) * and performs one simulation step in Stella. */ -reward_t StellaEnvironment::oneStepAct(Action player_a_action, - Action player_b_action) { +reward_t StellaEnvironment::oneStepAct(Action player_a_action, float paddle_a_strength, + float player_b_action, float paddle_b_strength) { // Once in a terminal state, refuse to go any further (special actions must be handled // outside of this environment; in particular reset() should be called rather than passing // RESET or SYSTEM_RESET. @@ -253,30 +213,8 @@ reward_t StellaEnvironment::oneStepAct(Action player_a_action, noopIllegalActions(player_a_action, player_b_action); // Emulate in the emulator - emulate(player_a_action, player_b_action); - // Increment the number of frames seen so far - m_state.incrementFrame(); - - return m_settings->getReward(); -} - -/** Applies the given continuous actions (e.g. updating paddle positions when - * the paddle is used) and performs one simulation step in Stella. */ -reward_t StellaEnvironment::oneStepActContinuous( - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire) { - // Once in a terminal state, refuse to go any further (special actions must be handled - // outside of this environment; in particular reset() should be called rather than passing - // RESET or SYSTEM_RESET. - if (isTerminal()) - return 0; - - // Convert illegal actions into NOOPs; actions such as reset are always legal - //noopIllegalActions(player_a_action, player_b_action); - - // Emulate in the emulator - emulateContinuous(player_a_r, player_a_theta, player_a_fire, - player_b_r, player_b_theta, player_b_fire); + emulate(player_a_action, paddle_a_strength, + player_b_action, paddle_b_strength); // Increment the number of frames seen so far m_state.incrementFrame(); @@ -324,39 +262,11 @@ void StellaEnvironment::setMode(game_mode_t value) { m_state.setCurrentMode(value); } -void StellaEnvironment::emulate(Action player_a_action, Action player_b_action, - size_t num_steps) { - Event* event = m_osystem->event(); - - // Handle paddles separately: we have to manually update the paddle positions at each step - if (m_use_paddles) { - // Run emulator forward for 'num_steps' - for (size_t t = 0; t < num_steps; t++) { - // Update paddle position at every step - m_state.applyActionPaddles(event, player_a_action, player_b_action); - - m_osystem->console().mediaSource().update(); - m_settings->step(m_osystem->console().system()); - } - } else { - // In joystick mode we only need to set the action events once - m_state.setActionJoysticks(event, player_a_action, player_b_action); - - for (size_t t = 0; t < num_steps; t++) { - m_osystem->console().mediaSource().update(); - m_settings->step(m_osystem->console().system()); - } - } - - // Parse screen and RAM into their respective data structures - processScreen(); - processRAM(); -} - -void StellaEnvironment::emulateContinuous( - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, - size_t num_steps) { +void StellaEnvironment::emulate( + Action player_a_action, float paddle_a_strength, + Action player_b_action, float paddle_b_strength, + size_t num_steps +) { Event* event = m_osystem->event(); // Handle paddles separately: we have to manually update the paddle positions at each step @@ -364,21 +274,18 @@ void StellaEnvironment::emulateContinuous( // Run emulator forward for 'num_steps' for (size_t t = 0; t < num_steps; t++) { // Update paddle position at every step - m_state.applyActionPaddlesContinuous( - event, - player_a_r, player_a_theta, player_a_fire, - player_b_r, player_b_theta, player_b_fire, - m_continuous_action_threshold); + m_state.applyActionPaddles( + event, + player_a_action, paddle_a_strength, + player_b_action, paddle_b_strength, + ); m_osystem->console().mediaSource().update(); m_settings->step(m_osystem->console().system()); } } else { // In joystick mode we only need to set the action events once - m_state.setActionJoysticksContinuous( - event, player_a_r, player_a_theta, player_a_fire, - player_b_r, player_b_theta, player_b_fire, - m_continuous_action_threshold); + m_state.applyActionJoysticks(event, player_a_action, player_b_action); for (size_t t = 0; t < num_steps; t++) { m_osystem->console().mediaSource().update(); diff --git a/src/environment/stella_environment.hpp b/src/environment/stella_environment.hpp index 405fd2ca1..d6d20c716 100644 --- a/src/environment/stella_environment.hpp +++ b/src/environment/stella_environment.hpp @@ -58,7 +58,8 @@ class StellaEnvironment { * Note that the post-act() frame number might not correspond to the pre-act() frame * number plus the frame skip. */ - reward_t act(Action player_a_action, Action player_b_action); + reward_t act(Action player_a_action, float paddle_a_strength, + Action player_b_action, float paddle_b_strength); /** Applies the given continuous actions (e.g. updating paddle positions when * the paddle is used) and performs one simulation step in Stella. Returns the @@ -67,9 +68,6 @@ class StellaEnvironment { * number might not correspond to the pre-act() frame number plus the frame * skip. */ - reward_t actContinuous( - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire); /** This functions emulates a push on the reset button of the console */ void softReset(); @@ -130,23 +128,13 @@ class StellaEnvironment { private: /** This applies an action exactly one time step. Helper function to act(). */ - reward_t oneStepAct(Action player_a_action, Action player_b_action); - - /** This applies a continuous action exactly one time step. - * Helper function to actContinuous(). - */ - reward_t oneStepActContinuous( - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire); - + reward_t oneStepAct(Action player_a_action, float paddle_a_strength, + Action player_b_action, float paddle_b_strength); /** Actually emulates the emulator for a given number of steps. */ - void emulate(Action player_a_action, Action player_b_action, + void emulate(Action player_a_action, float paddle_a_strength, + Action player_b_action, float paddle_b_strength, size_t num_steps = 1); - void emulateContinuous( - float player_a_r, float player_a_theta, float player_a_fire, - float player_b_r, float player_b_theta, float player_b_fire, - size_t num_steps = 1); /** Drops illegal actions, such as the fire button in skiing. Note that this is different * from the minimal set of actions. */ From c3a91641fca0ebe23dbb2c8aaeeeab58cdbda290 Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 00:05:32 +0900 Subject: [PATCH 09/39] remove redundant params --- src/emucore/Settings.cxx | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/emucore/Settings.cxx b/src/emucore/Settings.cxx index bd238c8ca..8e9193aae 100644 --- a/src/emucore/Settings.cxx +++ b/src/emucore/Settings.cxx @@ -415,8 +415,6 @@ void Settings::setDefaultSettings() { boolSettings.insert(std::pair("send_rgb", false)); intSettings.insert(std::pair("frame_skip", 1)); floatSettings.insert(std::pair("repeat_action_probability", 0.25)); - floatSettings.insert(std::pair("joystick_discrete_threshold", 0.5)); - floatSettings.insert(std::pair("fire_discrete_threshold", 0.0)); stringSettings.insert(std::pair("rom_file", "")); // Whether to truncate an episode on loss of life. boolSettings.insert(std::pair("truncate_on_loss_of_life", false)); From f25a114ec07ba2cbd31c110a9fae10a19167a6f4 Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 00:10:15 +0900 Subject: [PATCH 10/39] make default parameter --- src/environment/stella_environment.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/environment/stella_environment.hpp b/src/environment/stella_environment.hpp index d6d20c716..7f3b700f7 100644 --- a/src/environment/stella_environment.hpp +++ b/src/environment/stella_environment.hpp @@ -58,8 +58,8 @@ class StellaEnvironment { * Note that the post-act() frame number might not correspond to the pre-act() frame * number plus the frame skip. */ - reward_t act(Action player_a_action, float paddle_a_strength, - Action player_b_action, float paddle_b_strength); + reward_t act(Action player_a_action, float paddle_a_strength = 1.0, + Action player_b_action, float paddle_b_strength = 1.0); /** Applies the given continuous actions (e.g. updating paddle positions when * the paddle is used) and performs one simulation step in Stella. Returns the @@ -164,7 +164,6 @@ class StellaEnvironment { int m_max_num_frames_per_episode; // Maxmimum number of frames per episode size_t m_frame_skip; // How many frames to emulate per act() float m_repeat_action_probability; // Stochasticity of the environment - float m_continuous_action_threshold; // Continuous action threshold std::unique_ptr m_screen_exporter; // Automatic screen recorder int m_max_lives; // Maximum number of lives at the start of an episode. bool m_truncate_on_loss_of_life; // Whether to truncate episodes on loss of life. @@ -173,6 +172,7 @@ class StellaEnvironment { // The last actions taken by our players Action m_player_a_action, m_player_b_action; + float m_paddle_a_strength, m_paddle_b_strength; float m_player_a_r, m_player_b_r; float m_player_a_theta, m_player_b_theta; float m_player_a_fire, m_player_b_fire; From 10903c514222e335a97dc69ad3e36c34d793163a Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 00:28:13 +0900 Subject: [PATCH 11/39] amend interface to have default parameter at top level --- src/ale_interface.cpp | 4 ++-- src/ale_interface.hpp | 2 +- src/environment/stella_environment.hpp | 4 ++-- src/python/__init__.pyi | 5 ++--- src/python/ale_python_interface.hpp | 5 ++--- 5 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/ale_interface.cpp b/src/ale_interface.cpp index 0e29c18ce..007ec6692 100644 --- a/src/ale_interface.cpp +++ b/src/ale_interface.cpp @@ -260,8 +260,8 @@ int ALEInterface::lives() { // when necessary - this method will keep pressing buttons on the // game over screen. // Intentionally set player B actions to 0 since we are in single player mode -reward_t ALEInterface::act(float r, float theta, float fire) { - return environment->act(r, theta, fire, 0.0, 0.0, 0.0); +reward_t ALEInterface::act(Action action, float paddle_strength) { + return environment->act(action, paddle_strength, PLAYER_B_NOOP, 0.0); } // Returns the vector of modes available for the current game. diff --git a/src/ale_interface.hpp b/src/ale_interface.hpp index 184d1754a..f7b4d3f1c 100644 --- a/src/ale_interface.hpp +++ b/src/ale_interface.hpp @@ -86,7 +86,7 @@ class ALEInterface { // user's responsibility to check if the game has ended and reset // when necessary - this method will keep pressing buttons on the // game over screen. - reward_t act(float r, float theta, float fire); + reward_t act(Action action, float paddle_strength = 1.0); // Indicates if the game has ended. bool game_over(bool with_truncation = true) const; diff --git a/src/environment/stella_environment.hpp b/src/environment/stella_environment.hpp index 7f3b700f7..0ae6ebabf 100644 --- a/src/environment/stella_environment.hpp +++ b/src/environment/stella_environment.hpp @@ -58,8 +58,8 @@ class StellaEnvironment { * Note that the post-act() frame number might not correspond to the pre-act() frame * number plus the frame skip. */ - reward_t act(Action player_a_action, float paddle_a_strength = 1.0, - Action player_b_action, float paddle_b_strength = 1.0); + reward_t act(Action player_a_action, float paddle_a_strength, + Action player_b_action, float paddle_b_strength); /** Applies the given continuous actions (e.g. updating paddle positions when * the paddle is used) and performs one simulation step in Stella. Returns the diff --git a/src/python/__init__.pyi b/src/python/__init__.pyi index cbc3f7881..4025420f0 100644 --- a/src/python/__init__.pyi +++ b/src/python/__init__.pyi @@ -102,10 +102,9 @@ class ALEState: class ALEInterface: def __init__(self) -> None: ... @overload - def act(self, action: Action) -> int: ... - def actContinuous(self, r: float, theta: float, fire: float) -> int: ... + def act(self, action: Action, paddle_strength: float = 1.0) -> int: ... @overload - def act(self, action: int) -> int: ... + def act(self, action: int, paddle_strength: float = 1.0) -> int: ... def cloneState(self, *, include_rng: bool = False) -> ALEState: ... def cloneSystemState(self) -> ALEState: ... def game_over(self, *, with_truncation: bool = True) -> bool: ... diff --git a/src/python/ale_python_interface.hpp b/src/python/ale_python_interface.hpp index 6377f8db7..52235432a 100644 --- a/src/python/ale_python_interface.hpp +++ b/src/python/ale_python_interface.hpp @@ -142,11 +142,10 @@ PYBIND11_MODULE(_ale_py, m) { .def("loadROM", &ale::ALEInterface::loadROM) .def_static("isSupportedROM", &ale::ALEPythonInterface::isSupportedROM) .def_static("isSupportedROM", &ale::ALEInterface::isSupportedROM) - .def("act", (ale::reward_t(ale::ALEPythonInterface::*)(uint32_t)) & + .def("act", (ale::reward_t(ale::ALEPythonInterface::*)(uint32_t, float)) & ale::ALEPythonInterface::act) - .def("act", (ale::reward_t(ale::ALEInterface::*)(ale::Action)) & + .def("act", (ale::reward_t(ale::ALEInterface::*)(ale::Action, float)) & ale::ALEInterface::act) - .def("actContinuous", &ale::ALEPythonInterface::actContinuous) .def("game_over", &ale::ALEPythonInterface::game_over, py::kw_only(), py::arg("with_truncation") = py::bool_(true)) .def("game_truncated", &ale::ALEPythonInterface::game_truncated) .def("reset_game", &ale::ALEPythonInterface::reset_game) From 84ca05508250463a10f8b66802e58cf158294f16 Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 00:34:35 +0900 Subject: [PATCH 12/39] swap parameter order and implement continuous for wrappers --- src/environment/stella_environment.cpp | 12 ++++++------ src/environment/stella_environment.hpp | 12 ++++++------ src/environment/stella_environment_wrapper.cpp | 8 +++++--- src/environment/stella_environment_wrapper.hpp | 3 ++- 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/environment/stella_environment.cpp b/src/environment/stella_environment.cpp index 90acc645b..3b9e5951b 100644 --- a/src/environment/stella_environment.cpp +++ b/src/environment/stella_environment.cpp @@ -152,8 +152,8 @@ void StellaEnvironment::noopIllegalActions(Action& player_a_action, player_b_action = (Action)PLAYER_B_NOOP; } -reward_t StellaEnvironment::act(Action player_a_action, float paddle_a_strength, - Action player_b_action, float paddle_b_strength) { +reward_t StellaEnvironment::act(Action player_a_action, Action player_b_action, + float paddle_a_strength, float paddle_b_strength) { // Total reward received as we repeat the action reward_t sum_rewards = 0; @@ -201,8 +201,8 @@ void StellaEnvironment::softReset() { /** Applies the given actions (e.g. updating paddle positions when the paddle is used) * and performs one simulation step in Stella. */ -reward_t StellaEnvironment::oneStepAct(Action player_a_action, float paddle_a_strength, - float player_b_action, float paddle_b_strength) { +reward_t StellaEnvironment::oneStepAct(Action player_a_action, Action player_b_action, + float paddle_a_strength, float paddle_b_strength) { // Once in a terminal state, refuse to go any further (special actions must be handled // outside of this environment; in particular reset() should be called rather than passing // RESET or SYSTEM_RESET. @@ -263,8 +263,8 @@ void StellaEnvironment::setMode(game_mode_t value) { } void StellaEnvironment::emulate( - Action player_a_action, float paddle_a_strength, - Action player_b_action, float paddle_b_strength, + Action player_a_action, Action player_b_action, + float paddle_a_strength, float paddle_b_strength, size_t num_steps ) { Event* event = m_osystem->event(); diff --git a/src/environment/stella_environment.hpp b/src/environment/stella_environment.hpp index 0ae6ebabf..f2a59a0b2 100644 --- a/src/environment/stella_environment.hpp +++ b/src/environment/stella_environment.hpp @@ -58,8 +58,8 @@ class StellaEnvironment { * Note that the post-act() frame number might not correspond to the pre-act() frame * number plus the frame skip. */ - reward_t act(Action player_a_action, float paddle_a_strength, - Action player_b_action, float paddle_b_strength); + reward_t act(Action player_a_action, Action player_b_action, + float paddle_a_strength = 1.0, float paddle_b_strength = 1.0); /** Applies the given continuous actions (e.g. updating paddle positions when * the paddle is used) and performs one simulation step in Stella. Returns the @@ -128,12 +128,12 @@ class StellaEnvironment { private: /** This applies an action exactly one time step. Helper function to act(). */ - reward_t oneStepAct(Action player_a_action, float paddle_a_strength, - Action player_b_action, float paddle_b_strength); + reward_t oneStepAct(Action player_a_action, Action player_b_action, + float paddle_a_strength, float paddle_b_strength); /** Actually emulates the emulator for a given number of steps. */ - void emulate(Action player_a_action, float paddle_a_strength, - Action player_b_action, float paddle_b_strength, + void emulate(Action player_a_action, Action player_b_action, + float paddle_a_strength, float paddle_b_strength, size_t num_steps = 1); /** Drops illegal actions, such as the fire button in skiing. Note that this is different diff --git a/src/environment/stella_environment_wrapper.cpp b/src/environment/stella_environment_wrapper.cpp index f518943e6..34fe8da7c 100644 --- a/src/environment/stella_environment_wrapper.cpp +++ b/src/environment/stella_environment_wrapper.cpp @@ -9,9 +9,11 @@ StellaEnvironmentWrapper::StellaEnvironmentWrapper( StellaEnvironment& environment) : m_environment(environment) {} -reward_t StellaEnvironmentWrapper::act(Action player_a_action, - Action player_b_action) { - return m_environment.act(player_a_action, player_b_action); +reward_t StellaEnvironmentWrapper::act(Action player_a_action, Action player_b_action, + float paddle_a_strength, float paddle_b_strength) { + return m_environment.act(player_a_action, player_b_action, + paddle_a_strength, paddle_b_strength + ); } void StellaEnvironmentWrapper::softReset() { m_environment.softReset(); } diff --git a/src/environment/stella_environment_wrapper.hpp b/src/environment/stella_environment_wrapper.hpp index a6f4fbe65..f1581955a 100644 --- a/src/environment/stella_environment_wrapper.hpp +++ b/src/environment/stella_environment_wrapper.hpp @@ -30,7 +30,8 @@ class StellaEnvironmentWrapper { // stella_environment.hpp. public: StellaEnvironmentWrapper(StellaEnvironment& environment); - reward_t act(Action player_a_action, Action player_b_action); + reward_t act(Action player_a_action, Action player_b_action, + float paddle_a_strength, float paddle_b_strength); void softReset(); void pressSelect(size_t num_steps = 1); stella::Random& getEnvironmentRNG(); From 9aef3d7ea9e6028311edcd80e4388e47238bbc97 Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 00:36:49 +0900 Subject: [PATCH 13/39] maybe stella shouldn't use default params --- src/environment/stella_environment.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/environment/stella_environment.hpp b/src/environment/stella_environment.hpp index f2a59a0b2..fc4b7278e 100644 --- a/src/environment/stella_environment.hpp +++ b/src/environment/stella_environment.hpp @@ -59,7 +59,7 @@ class StellaEnvironment { * number plus the frame skip. */ reward_t act(Action player_a_action, Action player_b_action, - float paddle_a_strength = 1.0, float paddle_b_strength = 1.0); + float paddle_a_strength, float paddle_b_strength); /** Applies the given continuous actions (e.g. updating paddle positions when * the paddle is used) and performs one simulation step in Stella. Returns the From 3a971e8b87d948aca5dfcf117f7616d780960a07 Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 01:33:01 +0900 Subject: [PATCH 14/39] move discretization to Python --- src/environment/ale_state.cpp | 4 +- src/python/env.py | 124 ++++++++++++++++++++++++++-------- 2 files changed, 100 insertions(+), 28 deletions(-) diff --git a/src/environment/ale_state.cpp b/src/environment/ale_state.cpp index b050bb07f..df9339f7d 100644 --- a/src/environment/ale_state.cpp +++ b/src/environment/ale_state.cpp @@ -215,6 +215,7 @@ void ALEState::applyActionPaddles(Event* event, case PLAYER_A_DOWNLEFTFIRE: delta_a = static_cast(PADDLE_DELTA * fabs(paddle_a_strength)); break; + default: break; } @@ -237,6 +238,7 @@ void ALEState::applyActionPaddles(Event* event, case PLAYER_B_DOWNLEFTFIRE: delta_b = static_cast(PADDLE_DELTA * fabs(paddle_b_strength)); break; + default: break; } @@ -300,7 +302,7 @@ void ALEState::setDifficultySwitches(Event* event, unsigned int value) { event->set(Event::ConsoleRightDifficultyB, !((value & 2) >> 1)); } -void ALEState::setActionJoysticks(Event* event, +void ALEState::applyActionJoysticks(Event* event, int player_a_action, int player_b_action) { // Reset keys resetKeys(event); diff --git a/src/python/env.py b/src/python/env.py index 129052fc8..2eb422757 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -1,7 +1,8 @@ from __future__ import annotations +from functools import lru_cache import sys -from typing import Any, Literal +from typing import Any, Literal, Sequence import ale_py import gymnasium @@ -142,8 +143,6 @@ def __init__( self.ale.setLoggerMode(ale_py.LoggerMode.Error) # Config sticky action prob. self.ale.setFloat("repeat_action_probability", repeat_action_probability) - # Config continuous action threshold (if using continuous actions). - self.ale.setFloat("continuous_action_threshold", continuous_action_threshold) if max_num_frames_per_episode is not None: self.ale.setInt("max_num_frames_per_episode", max_num_frames_per_episode) @@ -157,23 +156,24 @@ def __init__( self.seed_game() self.load_game() + # get the set of legal actions + self._action_set = ( + self.ale.getLegalActionSet() + if full_action_space + else self.ale.getMinimalActionSet() + ) + + # action space self.continuous = continuous self.continuous_action_threshold = continuous_action_threshold if continuous: - # We don't need action_set for continuous actions. - self._action_set = None # Actions are radius, theta, and fire, where first two are the # parameters of polar coordinates. self.action_space = spaces.Box( - np.array([0, -np.pi, 0]).astype(np.float32), - np.array([+1, +np.pi, +1]).astype(np.float32), + np.array([0.0, -np.pi, 0.0]).astype(np.float32), + np.array([1.0, np.pi, 1.0]).astype(np.float32), ) # radius, theta, fire. First two are polar coordinates. else: - self._action_set = ( - self.ale.getLegalActionSet() - if full_action_space - else self.ale.getMinimalActionSet() - ) self.action_space = spaces.Discrete(len(self._action_set)) # initialize observation space @@ -239,14 +239,15 @@ def reset( # pyright: ignore[reportIncompatibleMethodOverride] def step( # pyright: ignore[reportIncompatibleMethodOverride] self, - action: int | np.ndarray, + action: int | Sequence[float], ) -> tuple[np.ndarray, float, bool, bool, AtariEnvStepMetadata]: """ Perform one agent step, i.e., repeats `action` frameskip # of steps. Args: - action_ind: int | np.ndarray => Action index to execute, or numpy - array of floats if continuous. + action: int | Sequence[float] => + if `continuous=False` -> action index to execute + if `continuous=True` -> numpy array of r, theta, fire Returns: tuple[np.ndarray, float, bool, bool, Dict[str, Any]] => @@ -268,13 +269,23 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] reward = 0.0 for _ in range(frameskip): if self.continuous: - if len(action) != 3: - raise error.Error("Actions must have 3-dimensions.") - - if isinstance(action, np.ndarray): - action = action.tolist() - r, theta, fire = action - reward += self.ale.actContinuous(r, theta, fire) + # compute the x, y, fire of the joystick + assert isinstance(action, Sequence) + strength = action[0] + x, y = action[0] * np.cos(action[1]), action[0] * np.sin(action[1]) + action = self.map_action_idx( + left_center_right=( + -(x < self.continuous_action_threshold) + +(x > self.continuous_action_threshold) + ), + down_up_center=( + -(y < self.continuous_action_threshold) + +(y > self.continuous_action_threshold) + ), + fire=(action[-1] > self.continuous_action_threshold), + ) + + reward += self.ale.act(action, strength) else: reward += self.ale.act(self._action_set[action]) is_terminal = self.ale.game_over(with_truncation=False) @@ -323,6 +334,7 @@ def _get_info(self) -> AtariEnvStepMetadata: "frame_number": self.ale.getFrameNumber(), } + @lru_cache(1) def get_keys_to_action(self) -> dict[tuple[int, ...], ale_py.Action]: """ Return keymapping -> actions for human play. @@ -358,12 +370,70 @@ def get_keys_to_action(self) -> dict[tuple[int, ...], ale_py.Action]: # (key, key, ...) -> action_idx # where action_idx is the integer value of the action enum # - return dict( - zip( - map(lambda action: tuple(sorted(mapping[action])), self._action_set), - range(len(self._action_set)), + return { + tuple(sorted(mapping[act_idx])): act_idx + for act_idx in self.action_set + } + + @lru_cache(18) + def map_action_idx(self, left_center_right: int, down_center_up: int, fire: bool) -> int: + """ + Return an action idx given unit actions for underlying env. + """ + # no op and fire + if left_center_right == 0 and down_center_up == 0 and not fire: + return ale_py.Action.NOOP + elif left_center_right == 0 and down_center_up == 0 and fire: + return ale_py.Action.FIRE + + # cardinal no fire + elif left_center_right == -1 and down_center_up == 0 and not fire: + return ale_py.Action.LEFT + elif left_center_right == 1 and down_center_up == 0 and not fire: + return ale_py.Action.RIGHT + elif left_center_right == 0 and down_center_up == -1 and not fire: + return ale_py.Action.DOWN + elif left_center_right == 0 and down_center_up == 1 and not fire: + return ale_py.Action.UP + + # cardinal fire + if left_center_right == -1 and down_center_up == 0 and fire: + return ale_py.Action.LEFTFIRE + elif left_center_right == 1 and down_center_up == 0 and fire: + return ale_py.Action.RIGHTFIRE + elif left_center_right == 0 and down_center_up == -1 and fire: + return ale_py.Action.DOWNFIRE + elif left_center_right == 0 and down_center_up == 1 and fire: + return ale_py.Action.UPFIRE + + # diagonal no fire + elif left_center_right == -1 and down_center_up == -1 and not fire: + return ale_py.Action.DOWNLEFT + elif left_center_right == 1 and down_center_up == -1 and not fire: + return ale_py.Action.DOWNRIGHT + elif left_center_right == -1 and down_center_up == 1 and not fire: + return ale_py.Action.UPLEFT + elif left_center_right == 1 and down_center_up == 1 and not fire: + return ale_py.Action.UPRIGHT + + # diagonal fire + elif left_center_right == -1 and down_center_up == -1 and fire: + return ale_py.Action.DOWNLEFTFIRE + elif left_center_right == 1 and down_center_up == -1 and fire: + return ale_py.Action.DOWNRIGHTFIRE + elif left_center_right == -1 and down_center_up == 1 and fire: + return ale_py.Action.UPLEFTFIRE + elif left_center_right == 1 and down_center_up == 1 and fire: + return ale_py.Action.UPRIGHTFIRE + + # just in case + else: + raise LookupError( + "Did not expect to get here, " + "expected `left_center_right` and `down_center_up` to be in {-1, 0, 1} " + "and `fire` to only be `True` or `False`. " + f"Received {left_center_right=}, {down_center_up=} and {fire=}." ) - ) def get_action_meanings(self) -> list[str]: """ From cef5892da4807da4399fc8849b26aaf29f5a429e Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 01:41:48 +0900 Subject: [PATCH 15/39] fix some bugs --- src/environment/stella_environment.cpp | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/environment/stella_environment.cpp b/src/environment/stella_environment.cpp index 3b9e5951b..668ce2a50 100644 --- a/src/environment/stella_environment.cpp +++ b/src/environment/stella_environment.cpp @@ -107,7 +107,7 @@ void StellaEnvironment::reset() { int noopSteps; noopSteps = 60; - emulate(PLAYER_A_NOOP, PLAYER_B_NOOP, noopSteps); + emulate(PLAYER_A_NOOP, PLAYER_B_NOOP, 0.0, 0.0, noopSteps); // Reset the emulator softReset(); @@ -122,7 +122,7 @@ void StellaEnvironment::reset() { // Apply necessary actions specified by the rom itself ActionVect startingActions = m_settings->getStartingActions(); for (size_t i = 0; i < startingActions.size(); i++) { - emulate(startingActions[i], PLAYER_B_NOOP); + emulate(startingActions[i], PLAYER_B_NOOP, 0.0, 0.0); } } @@ -163,13 +163,15 @@ reward_t StellaEnvironment::act(Action player_a_action, Action player_b_action, // past the terminal state for (size_t i = 0; i < m_frame_skip; i++) { // Stochastically drop actions, according to m_repeat_action_probability - if (rng.nextDouble() >= m_repeat_action_probability) + if (rng.nextDouble() >= m_repeat_action_probability) { m_player_a_action = player_a_action; m_paddle_a_strength = paddle_a_strength; + } // @todo Possibly optimize by avoiding call to rand() when player B is "off" ? - if (rng.nextDouble() >= m_repeat_action_probability) + if (rng.nextDouble() >= m_repeat_action_probability) { m_player_b_action = player_b_action; m_paddle_b_strength = paddle_b_strength; + } // If so desired, request one frame's worth of sound (this does nothing if recording // is not enabled) @@ -183,8 +185,8 @@ reward_t StellaEnvironment::act(Action player_a_action, Action player_b_action, m_screen_exporter->saveNext(m_screen); // Use the stored actions, which may or may not have changed this frame - sum_rewards += oneStepAct(m_player_a_action, m_player_a_strength, - m_player_b_action, m_player_b_strength); + sum_rewards += oneStepAct(m_player_a_action, m_player_b_action, + m_paddle_a_strength, m_paddle_b_strength); } return std::clamp(sum_rewards, m_reward_min, m_reward_max); @@ -192,7 +194,7 @@ reward_t StellaEnvironment::act(Action player_a_action, Action player_b_action, /** This functions emulates a push on the reset button of the console */ void StellaEnvironment::softReset() { - emulate(RESET, PLAYER_B_NOOP, m_num_reset_steps); + emulate(RESET, PLAYER_B_NOOP, 0.0, 0.0, m_num_reset_steps); // Reset previous actions to NOOP for correct action repeating m_player_a_action = PLAYER_A_NOOP; @@ -213,8 +215,8 @@ reward_t StellaEnvironment::oneStepAct(Action player_a_action, Action player_b_a noopIllegalActions(player_a_action, player_b_action); // Emulate in the emulator - emulate(player_a_action, paddle_a_strength, - player_b_action, paddle_b_strength); + emulate(player_a_action, player_b_action, + paddle_a_strength, paddle_b_strength); // Increment the number of frames seen so far m_state.incrementFrame(); @@ -250,7 +252,7 @@ void StellaEnvironment::pressSelect(size_t num_steps) { } processScreen(); processRAM(); - emulate(PLAYER_A_NOOP, PLAYER_B_NOOP); + emulate(PLAYER_A_NOOP, PLAYER_B_NOOP, 0.0, 0.0); m_state.incrementFrame(); } From b9ac2738773b9f95eb34fef008d75accec473b15 Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 01:42:45 +0900 Subject: [PATCH 16/39] fix another bug --- src/ale_interface.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ale_interface.cpp b/src/ale_interface.cpp index 007ec6692..9799bd67f 100644 --- a/src/ale_interface.cpp +++ b/src/ale_interface.cpp @@ -261,7 +261,7 @@ int ALEInterface::lives() { // game over screen. // Intentionally set player B actions to 0 since we are in single player mode reward_t ALEInterface::act(Action action, float paddle_strength) { - return environment->act(action, paddle_strength, PLAYER_B_NOOP, 0.0); + return environment->act(action, PLAYER_B_NOOP, paddle_strength, 0.0); } // Returns the vector of modes available for the current game. From 4189fb82eeffd45b60117dc530c92ee7997aea17 Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 01:52:51 +0900 Subject: [PATCH 17/39] stash --- src/environment/ale_state.cpp | 142 ++++++++++++------------- src/environment/stella_environment.cpp | 2 +- src/environment/stella_environment.hpp | 2 +- 3 files changed, 73 insertions(+), 73 deletions(-) diff --git a/src/environment/ale_state.cpp b/src/environment/ale_state.cpp index df9339f7d..33f987b27 100644 --- a/src/environment/ale_state.cpp +++ b/src/environment/ale_state.cpp @@ -244,7 +244,7 @@ void ALEState::applyActionPaddles(Event* event, } // Now update the paddle positions - updatePaddlePositions(event, delta_left, delta_right); + updatePaddlePositions(event, delta_a, delta_b); // Handle reset if (player_a_action == RESET || player_b_action == RESET) @@ -302,82 +302,82 @@ void ALEState::setDifficultySwitches(Event* event, unsigned int value) { event->set(Event::ConsoleRightDifficultyB, !((value & 2) >> 1)); } -void ALEState::applyActionJoysticks(Event* event, +void ALEState::applyActionJoysticks(Event* event_obj, int player_a_action, int player_b_action) { // Reset keys - resetKeys(event); + resetKeys(event_obj); switch (player_a_action) { case PLAYER_A_NOOP: break; case PLAYER_A_FIRE: - event->set(Event::JoystickZeroFire, 1); + event_obj->set(Event::JoystickZeroFire, 1); break; case PLAYER_A_UP: - event->set(Event::JoystickZeroUp, 1); + event_obj->set(Event::JoystickZeroUp, 1); break; case PLAYER_A_RIGHT: - event->set(Event::JoystickZeroRight, 1); + event_obj->set(Event::JoystickZeroRight, 1); break; case PLAYER_A_LEFT: - event->set(Event::JoystickZeroLeft, 1); + event_obj->set(Event::JoystickZeroLeft, 1); break; case PLAYER_A_DOWN: - event->set(Event::JoystickZeroDown, 1); + event_obj->set(Event::JoystickZeroDown, 1); break; case PLAYER_A_UPRIGHT: - event->set(Event::JoystickZeroUp, 1); - event->set(Event::JoystickZeroRight, 1); + event_obj->set(Event::JoystickZeroUp, 1); + event_obj->set(Event::JoystickZeroRight, 1); break; case PLAYER_A_UPLEFT: - event->set(Event::JoystickZeroUp, 1); - event->set(Event::JoystickZeroLeft, 1); + event_obj->set(Event::JoystickZeroUp, 1); + event_obj->set(Event::JoystickZeroLeft, 1); break; case PLAYER_A_DOWNRIGHT: - event->set(Event::JoystickZeroDown, 1); - event->set(Event::JoystickZeroRight, 1); + event_obj->set(Event::JoystickZeroDown, 1); + event_obj->set(Event::JoystickZeroRight, 1); break; case PLAYER_A_DOWNLEFT: - event->set(Event::JoystickZeroDown, 1); - event->set(Event::JoystickZeroLeft, 1); + event_obj->set(Event::JoystickZeroDown, 1); + event_obj->set(Event::JoystickZeroLeft, 1); break; case PLAYER_A_UPFIRE: - event->set(Event::JoystickZeroUp, 1); - event->set(Event::JoystickZeroFire, 1); + event_obj->set(Event::JoystickZeroUp, 1); + event_obj->set(Event::JoystickZeroFire, 1); break; case PLAYER_A_RIGHTFIRE: - event->set(Event::JoystickZeroRight, 1); - event->set(Event::JoystickZeroFire, 1); + event_obj->set(Event::JoystickZeroRight, 1); + event_obj->set(Event::JoystickZeroFire, 1); break; case PLAYER_A_LEFTFIRE: - event->set(Event::JoystickZeroLeft, 1); - event->set(Event::JoystickZeroFire, 1); + event_obj->set(Event::JoystickZeroLeft, 1); + event_obj->set(Event::JoystickZeroFire, 1); break; case PLAYER_A_DOWNFIRE: - event->set(Event::JoystickZeroDown, 1); - event->set(Event::JoystickZeroFire, 1); + event_obj->set(Event::JoystickZeroDown, 1); + event_obj->set(Event::JoystickZeroFire, 1); break; case PLAYER_A_UPRIGHTFIRE: - event->set(Event::JoystickZeroUp, 1); - event->set(Event::JoystickZeroRight, 1); - event->set(Event::JoystickZeroFire, 1); + event_obj->set(Event::JoystickZeroUp, 1); + event_obj->set(Event::JoystickZeroRight, 1); + event_obj->set(Event::JoystickZeroFire, 1); break; case PLAYER_A_UPLEFTFIRE: - event->set(Event::JoystickZeroUp, 1); - event->set(Event::JoystickZeroLeft, 1); - event->set(Event::JoystickZeroFire, 1); + event_obj->set(Event::JoystickZeroUp, 1); + event_obj->set(Event::JoystickZeroLeft, 1); + event_obj->set(Event::JoystickZeroFire, 1); break; case PLAYER_A_DOWNRIGHTFIRE: - event->set(Event::JoystickZeroDown, 1); - event->set(Event::JoystickZeroRight, 1); - event->set(Event::JoystickZeroFire, 1); + event_obj->set(Event::JoystickZeroDown, 1); + event_obj->set(Event::JoystickZeroRight, 1); + event_obj->set(Event::JoystickZeroFire, 1); break; case PLAYER_A_DOWNLEFTFIRE: - event->set(Event::JoystickZeroDown, 1); - event->set(Event::JoystickZeroLeft, 1); - event->set(Event::JoystickZeroFire, 1); + event_obj->set(Event::JoystickZeroDown, 1); + event_obj->set(Event::JoystickZeroLeft, 1); + event_obj->set(Event::JoystickZeroFire, 1); break; case RESET: - event->set(Event::ConsoleReset, 1); + event_obj->set(Event::ConsoleReset, 1); Logger::Info << "Sending Reset...\n"; break; default: @@ -388,74 +388,74 @@ void ALEState::applyActionJoysticks(Event* event, case PLAYER_B_NOOP: break; case PLAYER_B_FIRE: - event->set(Event::JoystickOneFire, 1); + event_obj->set(Event::JoystickOneFire, 1); break; case PLAYER_B_UP: - event->set(Event::JoystickOneUp, 1); + event_obj->set(Event::JoystickOneUp, 1); break; case PLAYER_B_RIGHT: - event->set(Event::JoystickOneRight, 1); + event_obj->set(Event::JoystickOneRight, 1); break; case PLAYER_B_LEFT: - event->set(Event::JoystickOneLeft, 1); + event_obj->set(Event::JoystickOneLeft, 1); break; case PLAYER_B_DOWN: - event->set(Event::JoystickOneDown, 1); + event_obj->set(Event::JoystickOneDown, 1); break; case PLAYER_B_UPRIGHT: - event->set(Event::JoystickOneUp, 1); - event->set(Event::JoystickOneRight, 1); + event_obj->set(Event::JoystickOneUp, 1); + event_obj->set(Event::JoystickOneRight, 1); break; case PLAYER_B_UPLEFT: - event->set(Event::JoystickOneUp, 1); - event->set(Event::JoystickOneLeft, 1); + event_obj->set(Event::JoystickOneUp, 1); + event_obj->set(Event::JoystickOneLeft, 1); break; case PLAYER_B_DOWNRIGHT: - event->set(Event::JoystickOneDown, 1); - event->set(Event::JoystickOneRight, 1); + event_obj->set(Event::JoystickOneDown, 1); + event_obj->set(Event::JoystickOneRight, 1); break; case PLAYER_B_DOWNLEFT: - event->set(Event::JoystickOneDown, 1); - event->set(Event::JoystickOneLeft, 1); + event_obj->set(Event::JoystickOneDown, 1); + event_obj->set(Event::JoystickOneLeft, 1); break; case PLAYER_B_UPFIRE: - event->set(Event::JoystickOneUp, 1); - event->set(Event::JoystickOneFire, 1); + event_obj->set(Event::JoystickOneUp, 1); + event_obj->set(Event::JoystickOneFire, 1); break; case PLAYER_B_RIGHTFIRE: - event->set(Event::JoystickOneRight, 1); - event->set(Event::JoystickOneFire, 1); + event_obj->set(Event::JoystickOneRight, 1); + event_obj->set(Event::JoystickOneFire, 1); break; case PLAYER_B_LEFTFIRE: - event->set(Event::JoystickOneLeft, 1); - event->set(Event::JoystickOneFire, 1); + event_obj->set(Event::JoystickOneLeft, 1); + event_obj->set(Event::JoystickOneFire, 1); break; case PLAYER_B_DOWNFIRE: - event->set(Event::JoystickOneDown, 1); - event->set(Event::JoystickOneFire, 1); + event_obj->set(Event::JoystickOneDown, 1); + event_obj->set(Event::JoystickOneFire, 1); break; case PLAYER_B_UPRIGHTFIRE: - event->set(Event::JoystickOneUp, 1); - event->set(Event::JoystickOneRight, 1); - event->set(Event::JoystickOneFire, 1); + event_obj->set(Event::JoystickOneUp, 1); + event_obj->set(Event::JoystickOneRight, 1); + event_obj->set(Event::JoystickOneFire, 1); break; case PLAYER_B_UPLEFTFIRE: - event->set(Event::JoystickOneUp, 1); - event->set(Event::JoystickOneLeft, 1); - event->set(Event::JoystickOneFire, 1); + event_obj->set(Event::JoystickOneUp, 1); + event_obj->set(Event::JoystickOneLeft, 1); + event_obj->set(Event::JoystickOneFire, 1); break; case PLAYER_B_DOWNRIGHTFIRE: - event->set(Event::JoystickOneDown, 1); - event->set(Event::JoystickOneRight, 1); - event->set(Event::JoystickOneFire, 1); + event_obj->set(Event::JoystickOneDown, 1); + event_obj->set(Event::JoystickOneRight, 1); + event_obj->set(Event::JoystickOneFire, 1); break; case PLAYER_B_DOWNLEFTFIRE: - event->set(Event::JoystickOneDown, 1); - event->set(Event::JoystickOneLeft, 1); - event->set(Event::JoystickOneFire, 1); + event_obj->set(Event::JoystickOneDown, 1); + event_obj->set(Event::JoystickOneLeft, 1); + event_obj->set(Event::JoystickOneFire, 1); break; case RESET: - event->set(Event::ConsoleReset, 1); + event_obj->set(Event::ConsoleReset, 1); Logger::Info << "Sending Reset...\n"; break; default: diff --git a/src/environment/stella_environment.cpp b/src/environment/stella_environment.cpp index 668ce2a50..510a14a47 100644 --- a/src/environment/stella_environment.cpp +++ b/src/environment/stella_environment.cpp @@ -279,7 +279,7 @@ void StellaEnvironment::emulate( m_state.applyActionPaddles( event, player_a_action, paddle_a_strength, - player_b_action, paddle_b_strength, + player_b_action, paddle_b_strength ); m_osystem->console().mediaSource().update(); diff --git a/src/environment/stella_environment.hpp b/src/environment/stella_environment.hpp index fc4b7278e..f2a59a0b2 100644 --- a/src/environment/stella_environment.hpp +++ b/src/environment/stella_environment.hpp @@ -59,7 +59,7 @@ class StellaEnvironment { * number plus the frame skip. */ reward_t act(Action player_a_action, Action player_b_action, - float paddle_a_strength, float paddle_b_strength); + float paddle_a_strength = 1.0, float paddle_b_strength = 1.0); /** Applies the given continuous actions (e.g. updating paddle positions when * the paddle is used) and performs one simulation step in Stella. Returns the From 9dfe31494da565705ab47abf9927b571660da749 Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 01:55:52 +0900 Subject: [PATCH 18/39] stash --- src/environment/ale_state.cpp | 140 +++++++++++++++++----------------- 1 file changed, 70 insertions(+), 70 deletions(-) diff --git a/src/environment/ale_state.cpp b/src/environment/ale_state.cpp index 33f987b27..481433e57 100644 --- a/src/environment/ale_state.cpp +++ b/src/environment/ale_state.cpp @@ -302,82 +302,82 @@ void ALEState::setDifficultySwitches(Event* event, unsigned int value) { event->set(Event::ConsoleRightDifficultyB, !((value & 2) >> 1)); } -void ALEState::applyActionJoysticks(Event* event_obj, +void ALEState::applyActionJoysticks(Event* event, int player_a_action, int player_b_action) { // Reset keys - resetKeys(event_obj); + resetKeys(event); switch (player_a_action) { case PLAYER_A_NOOP: break; case PLAYER_A_FIRE: - event_obj->set(Event::JoystickZeroFire, 1); + event->set(Event::JoystickZeroFire, 1); break; case PLAYER_A_UP: - event_obj->set(Event::JoystickZeroUp, 1); + event->set(Event::JoystickZeroUp, 1); break; case PLAYER_A_RIGHT: - event_obj->set(Event::JoystickZeroRight, 1); + event->set(Event::JoystickZeroRight, 1); break; case PLAYER_A_LEFT: - event_obj->set(Event::JoystickZeroLeft, 1); + event->set(Event::JoystickZeroLeft, 1); break; case PLAYER_A_DOWN: - event_obj->set(Event::JoystickZeroDown, 1); + event->set(Event::JoystickZeroDown, 1); break; case PLAYER_A_UPRIGHT: - event_obj->set(Event::JoystickZeroUp, 1); - event_obj->set(Event::JoystickZeroRight, 1); + event->set(Event::JoystickZeroUp, 1); + event->set(Event::JoystickZeroRight, 1); break; case PLAYER_A_UPLEFT: - event_obj->set(Event::JoystickZeroUp, 1); - event_obj->set(Event::JoystickZeroLeft, 1); + event->set(Event::JoystickZeroUp, 1); + event->set(Event::JoystickZeroLeft, 1); break; case PLAYER_A_DOWNRIGHT: - event_obj->set(Event::JoystickZeroDown, 1); - event_obj->set(Event::JoystickZeroRight, 1); + event->set(Event::JoystickZeroDown, 1); + event->set(Event::JoystickZeroRight, 1); break; case PLAYER_A_DOWNLEFT: - event_obj->set(Event::JoystickZeroDown, 1); - event_obj->set(Event::JoystickZeroLeft, 1); + event->set(Event::JoystickZeroDown, 1); + event->set(Event::JoystickZeroLeft, 1); break; case PLAYER_A_UPFIRE: - event_obj->set(Event::JoystickZeroUp, 1); - event_obj->set(Event::JoystickZeroFire, 1); + event->set(Event::JoystickZeroUp, 1); + event->set(Event::JoystickZeroFire, 1); break; case PLAYER_A_RIGHTFIRE: - event_obj->set(Event::JoystickZeroRight, 1); - event_obj->set(Event::JoystickZeroFire, 1); + event->set(Event::JoystickZeroRight, 1); + event->set(Event::JoystickZeroFire, 1); break; case PLAYER_A_LEFTFIRE: - event_obj->set(Event::JoystickZeroLeft, 1); - event_obj->set(Event::JoystickZeroFire, 1); + event->set(Event::JoystickZeroLeft, 1); + event->set(Event::JoystickZeroFire, 1); break; case PLAYER_A_DOWNFIRE: - event_obj->set(Event::JoystickZeroDown, 1); - event_obj->set(Event::JoystickZeroFire, 1); + event->set(Event::JoystickZeroDown, 1); + event->set(Event::JoystickZeroFire, 1); break; case PLAYER_A_UPRIGHTFIRE: - event_obj->set(Event::JoystickZeroUp, 1); - event_obj->set(Event::JoystickZeroRight, 1); - event_obj->set(Event::JoystickZeroFire, 1); + event->set(Event::JoystickZeroUp, 1); + event->set(Event::JoystickZeroRight, 1); + event->set(Event::JoystickZeroFire, 1); break; case PLAYER_A_UPLEFTFIRE: - event_obj->set(Event::JoystickZeroUp, 1); - event_obj->set(Event::JoystickZeroLeft, 1); - event_obj->set(Event::JoystickZeroFire, 1); + event->set(Event::JoystickZeroUp, 1); + event->set(Event::JoystickZeroLeft, 1); + event->set(Event::JoystickZeroFire, 1); break; case PLAYER_A_DOWNRIGHTFIRE: - event_obj->set(Event::JoystickZeroDown, 1); - event_obj->set(Event::JoystickZeroRight, 1); - event_obj->set(Event::JoystickZeroFire, 1); + event->set(Event::JoystickZeroDown, 1); + event->set(Event::JoystickZeroRight, 1); + event->set(Event::JoystickZeroFire, 1); break; case PLAYER_A_DOWNLEFTFIRE: - event_obj->set(Event::JoystickZeroDown, 1); - event_obj->set(Event::JoystickZeroLeft, 1); - event_obj->set(Event::JoystickZeroFire, 1); + event->set(Event::JoystickZeroDown, 1); + event->set(Event::JoystickZeroLeft, 1); + event->set(Event::JoystickZeroFire, 1); break; case RESET: - event_obj->set(Event::ConsoleReset, 1); + event->set(Event::ConsoleReset, 1); Logger::Info << "Sending Reset...\n"; break; default: @@ -388,74 +388,74 @@ void ALEState::applyActionJoysticks(Event* event_obj, case PLAYER_B_NOOP: break; case PLAYER_B_FIRE: - event_obj->set(Event::JoystickOneFire, 1); + event->set(Event::JoystickOneFire, 1); break; case PLAYER_B_UP: - event_obj->set(Event::JoystickOneUp, 1); + event->set(Event::JoystickOneUp, 1); break; case PLAYER_B_RIGHT: - event_obj->set(Event::JoystickOneRight, 1); + event->set(Event::JoystickOneRight, 1); break; case PLAYER_B_LEFT: - event_obj->set(Event::JoystickOneLeft, 1); + event->set(Event::JoystickOneLeft, 1); break; case PLAYER_B_DOWN: - event_obj->set(Event::JoystickOneDown, 1); + event->set(Event::JoystickOneDown, 1); break; case PLAYER_B_UPRIGHT: - event_obj->set(Event::JoystickOneUp, 1); - event_obj->set(Event::JoystickOneRight, 1); + event->set(Event::JoystickOneUp, 1); + event->set(Event::JoystickOneRight, 1); break; case PLAYER_B_UPLEFT: - event_obj->set(Event::JoystickOneUp, 1); - event_obj->set(Event::JoystickOneLeft, 1); + event->set(Event::JoystickOneUp, 1); + event->set(Event::JoystickOneLeft, 1); break; case PLAYER_B_DOWNRIGHT: - event_obj->set(Event::JoystickOneDown, 1); - event_obj->set(Event::JoystickOneRight, 1); + event->set(Event::JoystickOneDown, 1); + event->set(Event::JoystickOneRight, 1); break; case PLAYER_B_DOWNLEFT: - event_obj->set(Event::JoystickOneDown, 1); - event_obj->set(Event::JoystickOneLeft, 1); + event->set(Event::JoystickOneDown, 1); + event->set(Event::JoystickOneLeft, 1); break; case PLAYER_B_UPFIRE: - event_obj->set(Event::JoystickOneUp, 1); - event_obj->set(Event::JoystickOneFire, 1); + event->set(Event::JoystickOneUp, 1); + event->set(Event::JoystickOneFire, 1); break; case PLAYER_B_RIGHTFIRE: - event_obj->set(Event::JoystickOneRight, 1); - event_obj->set(Event::JoystickOneFire, 1); + event->set(Event::JoystickOneRight, 1); + event->set(Event::JoystickOneFire, 1); break; case PLAYER_B_LEFTFIRE: - event_obj->set(Event::JoystickOneLeft, 1); - event_obj->set(Event::JoystickOneFire, 1); + event->set(Event::JoystickOneLeft, 1); + event->set(Event::JoystickOneFire, 1); break; case PLAYER_B_DOWNFIRE: - event_obj->set(Event::JoystickOneDown, 1); - event_obj->set(Event::JoystickOneFire, 1); + event->set(Event::JoystickOneDown, 1); + event->set(Event::JoystickOneFire, 1); break; case PLAYER_B_UPRIGHTFIRE: - event_obj->set(Event::JoystickOneUp, 1); - event_obj->set(Event::JoystickOneRight, 1); - event_obj->set(Event::JoystickOneFire, 1); + event->set(Event::JoystickOneUp, 1); + event->set(Event::JoystickOneRight, 1); + event->set(Event::JoystickOneFire, 1); break; case PLAYER_B_UPLEFTFIRE: - event_obj->set(Event::JoystickOneUp, 1); - event_obj->set(Event::JoystickOneLeft, 1); - event_obj->set(Event::JoystickOneFire, 1); + event->set(Event::JoystickOneUp, 1); + event->set(Event::JoystickOneLeft, 1); + event->set(Event::JoystickOneFire, 1); break; case PLAYER_B_DOWNRIGHTFIRE: - event_obj->set(Event::JoystickOneDown, 1); - event_obj->set(Event::JoystickOneRight, 1); - event_obj->set(Event::JoystickOneFire, 1); + event->set(Event::JoystickOneDown, 1); + event->set(Event::JoystickOneRight, 1); + event->set(Event::JoystickOneFire, 1); break; case PLAYER_B_DOWNLEFTFIRE: - event_obj->set(Event::JoystickOneDown, 1); - event_obj->set(Event::JoystickOneLeft, 1); - event_obj->set(Event::JoystickOneFire, 1); + event->set(Event::JoystickOneDown, 1); + event->set(Event::JoystickOneLeft, 1); + event->set(Event::JoystickOneFire, 1); break; case RESET: - event_obj->set(Event::ConsoleReset, 1); + event->set(Event::ConsoleReset, 1); Logger::Info << "Sending Reset...\n"; break; default: From 3c9c8aabdf2ffb941990185da7b91f0db01fae10 Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 02:02:12 +0900 Subject: [PATCH 19/39] fix some more bugs --- src/environment/ale_state.cpp | 2 +- src/environment/stella_environment_wrapper.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/environment/ale_state.cpp b/src/environment/ale_state.cpp index 481433e57..5cb912668 100644 --- a/src/environment/ale_state.cpp +++ b/src/environment/ale_state.cpp @@ -303,7 +303,7 @@ void ALEState::setDifficultySwitches(Event* event, unsigned int value) { } void ALEState::applyActionJoysticks(Event* event, - int player_a_action, int player_b_action) { + int player_a_action, int player_b_action) { // Reset keys resetKeys(event); switch (player_a_action) { diff --git a/src/environment/stella_environment_wrapper.hpp b/src/environment/stella_environment_wrapper.hpp index f1581955a..9e97ee031 100644 --- a/src/environment/stella_environment_wrapper.hpp +++ b/src/environment/stella_environment_wrapper.hpp @@ -31,7 +31,7 @@ class StellaEnvironmentWrapper { public: StellaEnvironmentWrapper(StellaEnvironment& environment); reward_t act(Action player_a_action, Action player_b_action, - float paddle_a_strength, float paddle_b_strength); + float paddle_a_strength = 1.0, float paddle_b_strength = 1.0); void softReset(); void pressSelect(size_t num_steps = 1); stella::Random& getEnvironmentRNG(); From 86722e931f350c50df0f0ec5ebe9f8a6d1fd0742 Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 02:04:51 +0900 Subject: [PATCH 20/39] ALWAYS the rogue curlies you gotta watch out for --- src/environment/ale_state.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/environment/ale_state.cpp b/src/environment/ale_state.cpp index 5cb912668..4b19c4622 100644 --- a/src/environment/ale_state.cpp +++ b/src/environment/ale_state.cpp @@ -286,8 +286,6 @@ void ALEState::applyActionPaddles(Event* event, } } -} - void ALEState::pressSelect(Event* event) { resetKeys(event); event->set(Event::ConsoleSelect, 1); From fae19c120ef3bde1e7c5c204e17602fe78133419 Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 02:06:17 +0900 Subject: [PATCH 21/39] streamline --- src/environment/ale_state.hpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/environment/ale_state.hpp b/src/environment/ale_state.hpp index 4e4d48e57..3c0d2a529 100644 --- a/src/environment/ale_state.hpp +++ b/src/environment/ale_state.hpp @@ -58,16 +58,16 @@ class ALEState { void resetPaddles(stella::Event*); //Apply the special select action - void pressSelect(stella::Event* event_obj); + void pressSelect(stella::Event* event); /** Applies paddle continuous actions. This actually modifies the game state * by updating the paddle resistances. */ - void applyActionPaddles(stella::Event* event_obj, + void applyActionPaddles(stella::Event* event, int player_a_action, float paddle_a_strength, int player_b_action, float paddle_b_strength); /** Sets the joystick events. No effect until the emulator is run forward. */ - void applyActionJoysticks(stella::Event* event_obj, + void applyActionJoysticks(stella::Event* event, int player_a_action, int player_b_action); void incrementFrame(int steps = 1); @@ -111,22 +111,22 @@ class ALEState { ALEState save(stella::OSystem* osystem, RomSettings* settings, std::optional rng, std::string md5); /** Reset key presses */ - void resetKeys(stella::Event* event_obj); + void resetKeys(stella::Event* event); /** Sets the paddle to a given position */ - void setPaddles(stella::Event* event_obj, int left, int right); + void setPaddles(stella::Event* event, int left, int right); /** Set the paddle min/max values */ void setPaddleLimits(int paddle_min_val, int paddle_max_val); /** Updates the paddle position by a delta amount. */ - void updatePaddlePositions(stella::Event* event_obj, int delta_x, int delta_y); + void updatePaddlePositions(stella::Event* event, int delta_x, int delta_y); /** Calculates the Paddle resistance, based on the given x val */ int calcPaddleResistance(int x_val); /** Applies the current difficulty setting, which is effectively part of the action */ - void setDifficultySwitches(stella::Event* event_obj, unsigned int value); + void setDifficultySwitches(stella::Event* event, unsigned int value); private: int m_left_paddle; // Current value for the left-paddle From 2e3d879dcc6d52d613c99c071e92eda50babb18d Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 02:12:30 +0900 Subject: [PATCH 22/39] fixing tests --- pyproject.toml | 2 +- src/python/env.py | 8 +++---- tests/python/test_atari_env.py | 42 ++++++++++------------------------ 3 files changed, 17 insertions(+), 35 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bd8dd0d0a..250e26901 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ authors = [ {name = "Michael Bowling"}, ] maintainers = [ - { name = "Farama Foundation", email = "contact@farama.org" }, + {name = "Farama Foundation", email = "contact@farama.org"}, {name = "Jesse Farebrother", email = "jfarebro@cs.mcgill.ca"}, ] classifiers = [ diff --git a/src/python/env.py b/src/python/env.py index 2eb422757..19c906128 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -2,7 +2,7 @@ from functools import lru_cache import sys -from typing import Any, Literal, Sequence +from typing import Any, Literal import ale_py import gymnasium @@ -239,13 +239,13 @@ def reset( # pyright: ignore[reportIncompatibleMethodOverride] def step( # pyright: ignore[reportIncompatibleMethodOverride] self, - action: int | Sequence[float], + action: int | np.ndarray, ) -> tuple[np.ndarray, float, bool, bool, AtariEnvStepMetadata]: """ Perform one agent step, i.e., repeats `action` frameskip # of steps. Args: - action: int | Sequence[float] => + action: int | np.ndarray => if `continuous=False` -> action index to execute if `continuous=True` -> numpy array of r, theta, fire @@ -270,7 +270,7 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] for _ in range(frameskip): if self.continuous: # compute the x, y, fire of the joystick - assert isinstance(action, Sequence) + assert isinstance(action, np.ndarray) strength = action[0] x, y = action[0] * np.cos(action[1]), action[0] * np.sin(action[1]) action = self.map_action_idx( diff --git a/tests/python/test_atari_env.py b/tests/python/test_atari_env.py index 54e2d113e..1aecd09b0 100644 --- a/tests/python/test_atari_env.py +++ b/tests/python/test_atari_env.py @@ -1,3 +1,4 @@ +import itertools import warnings from itertools import product from unittest.mock import patch @@ -34,14 +35,17 @@ def test_roms_register(): @pytest.mark.parametrize( - "env_id", - [ - env_id - for env_id, spec in gymnasium.registry.items() - if spec.entry_point == "ale_py.env:AtariEnv" - ], + "env_id,continuous", + itertools.product( + [ + env_id + for env_id, spec in gymnasium.registry.items() + if spec.entry_point == "ale_py.env:AtariEnv" + ], + [True, False], + ), ) -def test_check_env(env_id): +def test_check_env(env_id, continuous): if any( unsupported_game in env_id for unsupported_game in ["Warlords", "MazeCraze", "Joust", "Combat"] @@ -49,7 +53,7 @@ def test_check_env(env_id): pytest.skip(env_id) with warnings.catch_warnings(record=True) as caught_warnings: - env = gymnasium.make(env_id).unwrapped + env = gymnasium.make(env_id, continuous=continuous).unwrapped check_env(env, skip_render_check=True) env.close() @@ -342,28 +346,6 @@ def test_continuous_action_space(tetris_env): ) -@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) -def test_continuous_action_sample(tetris_env): - tetris_env.reset(seed=0) - for _ in range(100): - tetris_env.step(tetris_env.action_space.sample()) - - -@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) -def test_continuous_step_with_correct_dimensions(tetris_env): - tetris_env.reset(seed=0) - # Test with both regular list and numpy array. - tetris_env.step([0.0, -0.5, 0.5]) - tetris_env.step(np.array([0.0, -0.5, 0.5])) - - -@pytest.mark.parametrize("tetris_env", [{"continuous": True}], indirect=True) -def test_continuous_step_fails_with_wrong_dimensions(tetris_env): - tetris_env.reset(seed=0) - with pytest.raises(gymnasium.error.Error): - tetris_env.step([0.0, 1.0]) - - def test_gym_reset_with_infos(tetris_env): pack = tetris_env.reset(seed=0) From d7f37b744d77f4caf2233d8bd8df66d183463f73 Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 02:13:14 +0900 Subject: [PATCH 23/39] make int --- src/python/env.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/python/env.py b/src/python/env.py index 19c906128..76ce84aba 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -275,12 +275,12 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] x, y = action[0] * np.cos(action[1]), action[0] * np.sin(action[1]) action = self.map_action_idx( left_center_right=( - -(x < self.continuous_action_threshold) - +(x > self.continuous_action_threshold) + -int(x < self.continuous_action_threshold) + +int(x > self.continuous_action_threshold) ), down_up_center=( - -(y < self.continuous_action_threshold) - +(y > self.continuous_action_threshold) + -int(y < self.continuous_action_threshold) + +int(y > self.continuous_action_threshold) ), fire=(action[-1] > self.continuous_action_threshold), ) From ea133aee01efe604f7d9f952992b5875354e314c Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 02:13:49 +0900 Subject: [PATCH 24/39] fix argument --- src/python/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/env.py b/src/python/env.py index 76ce84aba..bdd64b535 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -278,7 +278,7 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] -int(x < self.continuous_action_threshold) +int(x > self.continuous_action_threshold) ), - down_up_center=( + down_center_up=( -int(y < self.continuous_action_threshold) +int(y > self.continuous_action_threshold) ), From 444dde3121c819d6e2e3f4cb6a46d304418e30ab Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 02:23:40 +0900 Subject: [PATCH 25/39] passing tests --- src/python/env.py | 43 ++++++++++++++++++---------------- tests/python/test_atari_env.py | 10 +++++--- 2 files changed, 30 insertions(+), 23 deletions(-) diff --git a/src/python/env.py b/src/python/env.py index bdd64b535..63a854776 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -265,29 +265,32 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] else: raise error.Error(f"Invalid frameskip type: {self._frameskip}") + # action formatting + if self.continuous: + # compute the x, y, fire of the joystick + assert isinstance(action, np.ndarray) + x, y = action[0] * np.cos(action[1]), action[0] * np.sin(action[1]) + action_idx = self.map_action_idx( + left_center_right=( + -int(x < self.continuous_action_threshold) + +int(x > self.continuous_action_threshold) + ), + down_center_up=( + -int(y < self.continuous_action_threshold) + +int(y > self.continuous_action_threshold) + ), + fire=(action[-1] > self.continuous_action_threshold), + ) + strength = action[0] + else: + action_idx = self._action_set[action] + strength = 1.0 + # Frameskip reward = 0.0 for _ in range(frameskip): - if self.continuous: - # compute the x, y, fire of the joystick - assert isinstance(action, np.ndarray) - strength = action[0] - x, y = action[0] * np.cos(action[1]), action[0] * np.sin(action[1]) - action = self.map_action_idx( - left_center_right=( - -int(x < self.continuous_action_threshold) - +int(x > self.continuous_action_threshold) - ), - down_center_up=( - -int(y < self.continuous_action_threshold) - +int(y > self.continuous_action_threshold) - ), - fire=(action[-1] > self.continuous_action_threshold), - ) - - reward += self.ale.act(action, strength) - else: - reward += self.ale.act(self._action_set[action]) + reward += self.ale.act(self._action_set[action_idx], strength) + is_terminal = self.ale.game_over(with_truncation=False) is_truncated = self.ale.game_truncated() diff --git a/tests/python/test_atari_env.py b/tests/python/test_atari_env.py index 1aecd09b0..4a5ee5cde 100644 --- a/tests/python/test_atari_env.py +++ b/tests/python/test_atari_env.py @@ -11,6 +11,10 @@ from gymnasium.utils.env_checker import check_env from utils import test_rom_path, tetris_env # noqa: F401 +_ACCEPTABLE_WARNING_SNIPPETS = [ + "is out of date. You should consider upgrading to version", + "we recommend using a symmetric and normalized space", +] def test_roms_register(): registered_roms = [ @@ -59,9 +63,9 @@ def test_check_env(env_id, continuous): env.close() for warning in caught_warnings: - if ( - "is out of date. You should consider upgrading to version" - not in warning.message.args[0] + if not any( + (snippet in warning.message.args[0]) + for snippet in _ACCEPTABLE_WARNING_SNIPPETS ): raise ValueError(warning.message.args[0]) From 2d86927828ba9a69c14d993466f58f39863bd003 Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 02:25:16 +0900 Subject: [PATCH 26/39] fix bug --- src/python/env.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/src/python/env.py b/src/python/env.py index 63a854776..99e43c603 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -270,17 +270,21 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] # compute the x, y, fire of the joystick assert isinstance(action, np.ndarray) x, y = action[0] * np.cos(action[1]), action[0] * np.sin(action[1]) - action_idx = self.map_action_idx( - left_center_right=( - -int(x < self.continuous_action_threshold) - +int(x > self.continuous_action_threshold) - ), - down_center_up=( - -int(y < self.continuous_action_threshold) - +int(y > self.continuous_action_threshold) - ), - fire=(action[-1] > self.continuous_action_threshold), - ) + action_idx = self._action_set[ + self.map_action_idx( + left_center_right=( + -int(x < self.continuous_action_threshold) + +int(x > self.continuous_action_threshold) + ), + down_center_up=( + -int(y < self.continuous_action_threshold) + +int(y > self.continuous_action_threshold) + ), + fire=( + action[-1] > self.continuous_action_threshold + ), + ) + ] strength = action[0] else: action_idx = self._action_set[action] @@ -289,7 +293,7 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] # Frameskip reward = 0.0 for _ in range(frameskip): - reward += self.ale.act(self._action_set[action_idx], strength) + reward += self.ale.act(action_idx, strength) is_terminal = self.ale.game_over(with_truncation=False) is_truncated = self.ale.game_truncated() From b58e9b2ef7dc789b559753c3cadf073ae66cdd1e Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 02:32:02 +0900 Subject: [PATCH 27/39] use full action space in continuous mode --- src/python/env.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/python/env.py b/src/python/env.py index 99e43c603..77e8275bb 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -3,6 +3,7 @@ from functools import lru_cache import sys from typing import Any, Literal +from warnings import warn import ale_py import gymnasium @@ -157,9 +158,15 @@ def __init__( self.load_game() # get the set of legal actions + if continuous and not full_action_space: + warn( + "`continuous` is set to `True`, but `full_action_space` is set to `False`. " + "This will error out when the continuous actions are discretized to illegal action spaces. " + "Therefore, `full_action_space` has been automatically set to `True`." + ) self._action_set = ( self.ale.getLegalActionSet() - if full_action_space + if (full_action_space or continuous) else self.ale.getMinimalActionSet() ) @@ -270,21 +277,20 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] # compute the x, y, fire of the joystick assert isinstance(action, np.ndarray) x, y = action[0] * np.cos(action[1]), action[0] * np.sin(action[1]) - action_idx = self._action_set[ - self.map_action_idx( - left_center_right=( - -int(x < self.continuous_action_threshold) + action_idx = self.map_action_idx( + left_center_right=( + -int(x < self.continuous_action_threshold) +int(x > self.continuous_action_threshold) - ), - down_center_up=( - -int(y < self.continuous_action_threshold) + ), + down_center_up=( + -int(y < self.continuous_action_threshold) +int(y > self.continuous_action_threshold) - ), - fire=( - action[-1] > self.continuous_action_threshold - ), - ) - ] + ), + fire=( + action[-1] > self.continuous_action_threshold + ), + ) + strength = action[0] else: action_idx = self._action_set[action] From a5df8883f8207bf544b70b98f4a64d3e94089deb Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 02:33:44 +0900 Subject: [PATCH 28/39] precommit --- src/python/env.py | 19 ++++++++----------- tests/python/test_atari_env.py | 1 + 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/python/env.py b/src/python/env.py index 77e8275bb..dc6a7a880 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -1,7 +1,7 @@ from __future__ import annotations -from functools import lru_cache import sys +from functools import lru_cache from typing import Any, Literal from warnings import warn @@ -280,15 +280,13 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride] action_idx = self.map_action_idx( left_center_right=( -int(x < self.continuous_action_threshold) - +int(x > self.continuous_action_threshold) + + int(x > self.continuous_action_threshold) ), down_center_up=( -int(y < self.continuous_action_threshold) - +int(y > self.continuous_action_threshold) - ), - fire=( - action[-1] > self.continuous_action_threshold + + int(y > self.continuous_action_threshold) ), + fire=(action[-1] > self.continuous_action_threshold), ) strength = action[0] @@ -383,13 +381,12 @@ def get_keys_to_action(self) -> dict[tuple[int, ...], ale_py.Action]: # (key, key, ...) -> action_idx # where action_idx is the integer value of the action enum # - return { - tuple(sorted(mapping[act_idx])): act_idx - for act_idx in self.action_set - } + return {tuple(sorted(mapping[act_idx])): act_idx for act_idx in self.action_set} @lru_cache(18) - def map_action_idx(self, left_center_right: int, down_center_up: int, fire: bool) -> int: + def map_action_idx( + self, left_center_right: int, down_center_up: int, fire: bool + ) -> int: """ Return an action idx given unit actions for underlying env. """ diff --git a/tests/python/test_atari_env.py b/tests/python/test_atari_env.py index 4a5ee5cde..477cfc8b8 100644 --- a/tests/python/test_atari_env.py +++ b/tests/python/test_atari_env.py @@ -16,6 +16,7 @@ "we recommend using a symmetric and normalized space", ] + def test_roms_register(): registered_roms = [ env_id From ea2303f021040ba6167aea36b33c3b90ef7eb7d2 Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 02:45:03 +0900 Subject: [PATCH 29/39] fix bug --- src/python/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/env.py b/src/python/env.py index dc6a7a880..797f9de2f 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -381,7 +381,7 @@ def get_keys_to_action(self) -> dict[tuple[int, ...], ale_py.Action]: # (key, key, ...) -> action_idx # where action_idx is the integer value of the action enum # - return {tuple(sorted(mapping[act_idx])): act_idx for act_idx in self.action_set} + return {tuple(sorted(mapping[act_idx])): act_idx for act_idx in self._action_set} @lru_cache(18) def map_action_idx( From 394c515418e414a41c2944ff33628dd3a3332c81 Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 02:48:58 +0900 Subject: [PATCH 30/39] additional warning --- tests/python/test_atari_env.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/test_atari_env.py b/tests/python/test_atari_env.py index 477cfc8b8..bff58900e 100644 --- a/tests/python/test_atari_env.py +++ b/tests/python/test_atari_env.py @@ -14,6 +14,7 @@ _ACCEPTABLE_WARNING_SNIPPETS = [ "is out of date. You should consider upgrading to version", "we recommend using a symmetric and normalized space", + "This will error out when the continuous actions are discretized to illegal action spaces", ] From be3e86995a2f2e87a59c3aaa77a3d28424323554 Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 02:53:12 +0900 Subject: [PATCH 31/39] precommit --- src/python/env.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/python/env.py b/src/python/env.py index 797f9de2f..db9801bc2 100644 --- a/src/python/env.py +++ b/src/python/env.py @@ -381,7 +381,9 @@ def get_keys_to_action(self) -> dict[tuple[int, ...], ale_py.Action]: # (key, key, ...) -> action_idx # where action_idx is the integer value of the action enum # - return {tuple(sorted(mapping[act_idx])): act_idx for act_idx in self._action_set} + return { + tuple(sorted(mapping[act_idx])): act_idx for act_idx in self._action_set + } @lru_cache(18) def map_action_idx( From c38033e936c2b5cb7604a867590b152cc3ac968f Mon Sep 17 00:00:00 2001 From: Jet Date: Sun, 4 Aug 2024 13:35:36 +0900 Subject: [PATCH 32/39] update interface signature --- src/python/ale_python_interface.hpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/python/ale_python_interface.hpp b/src/python/ale_python_interface.hpp index 52235432a..d6cb5920a 100644 --- a/src/python/ale_python_interface.hpp +++ b/src/python/ale_python_interface.hpp @@ -142,8 +142,12 @@ PYBIND11_MODULE(_ale_py, m) { .def("loadROM", &ale::ALEInterface::loadROM) .def_static("isSupportedROM", &ale::ALEPythonInterface::isSupportedROM) .def_static("isSupportedROM", &ale::ALEInterface::isSupportedROM) + .def("act", (ale::reward_t(ale::ALEPythonInterface::*)(uint32_t)) & + ale::ALEPythonInterface::act) .def("act", (ale::reward_t(ale::ALEPythonInterface::*)(uint32_t, float)) & ale::ALEPythonInterface::act) + .def("act", (ale::reward_t(ale::ALEInterface::*)(ale::Action)) & + ale::ALEInterface::act) .def("act", (ale::reward_t(ale::ALEInterface::*)(ale::Action, float)) & ale::ALEInterface::act) .def("game_over", &ale::ALEPythonInterface::game_over, py::kw_only(), py::arg("with_truncation") = py::bool_(true)) From 140f95ea5738b3f29f6be5046bcc987e4f64fe14 Mon Sep 17 00:00:00 2001 From: Jet Date: Tue, 6 Aug 2024 13:29:50 +0900 Subject: [PATCH 33/39] change to default emulate strength 1.0 --- src/environment/stella_environment.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/environment/stella_environment.cpp b/src/environment/stella_environment.cpp index 510a14a47..3c16e9f9e 100644 --- a/src/environment/stella_environment.cpp +++ b/src/environment/stella_environment.cpp @@ -107,7 +107,7 @@ void StellaEnvironment::reset() { int noopSteps; noopSteps = 60; - emulate(PLAYER_A_NOOP, PLAYER_B_NOOP, 0.0, 0.0, noopSteps); + emulate(PLAYER_A_NOOP, PLAYER_B_NOOP, 1.0, 1.0, noopSteps); // Reset the emulator softReset(); @@ -122,7 +122,7 @@ void StellaEnvironment::reset() { // Apply necessary actions specified by the rom itself ActionVect startingActions = m_settings->getStartingActions(); for (size_t i = 0; i < startingActions.size(); i++) { - emulate(startingActions[i], PLAYER_B_NOOP, 0.0, 0.0); + emulate(startingActions[i], PLAYER_B_NOOP, 1.0, 1.0); } } @@ -194,7 +194,7 @@ reward_t StellaEnvironment::act(Action player_a_action, Action player_b_action, /** This functions emulates a push on the reset button of the console */ void StellaEnvironment::softReset() { - emulate(RESET, PLAYER_B_NOOP, 0.0, 0.0, m_num_reset_steps); + emulate(RESET, PLAYER_B_NOOP, 1.0, 1.0, m_num_reset_steps); // Reset previous actions to NOOP for correct action repeating m_player_a_action = PLAYER_A_NOOP; @@ -252,7 +252,7 @@ void StellaEnvironment::pressSelect(size_t num_steps) { } processScreen(); processRAM(); - emulate(PLAYER_A_NOOP, PLAYER_B_NOOP, 0.0, 0.0); + emulate(PLAYER_A_NOOP, PLAYER_B_NOOP, 1.0, 1.0); m_state.incrementFrame(); } From d46de36dc5c24bb203b14e742a4e54c42540032f Mon Sep 17 00:00:00 2001 From: Jet Date: Sat, 10 Aug 2024 13:24:44 +0900 Subject: [PATCH 34/39] perform a version bump --- .github/workflows/ci.yml | 2 +- CHANGELOG.md | 4 ++++ vcpkg.json | 2 +- version.txt | 2 +- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7296fde8d..59aef3911 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -289,7 +289,7 @@ jobs: - name: Build # wildcarding doesn't work for some reason, therefore, update the project version here - run: python -m pip install wheels/ale_py-0.9.1-${{ matrix.wheel-name }}.whl + run: python -m pip install wheels/ale_py-0.10.0-${{ matrix.wheel-name }}.whl - name: Install Gymnasium and pytest run: python -m pip install gymnasium>=1.0.0a2 pytest diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c1bd7fae..18a646d04 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## 0.10.0 - + + + ## [0.9.0] - 2024-05-10 Previously, ALE implemented only a [Gym](https://github.com/openai/gym) based environment, however, as Gym is no longer maintained (last commit was 18 months ago). We have updated `ale-py` to use [Gymnasium](http://github.com/farama-Foundation/gymnasium) (a maintained fork of Gym) as the sole backend environment implementation. For more information on Gymnasium’s API, see their [introduction page](https://gymnasium.farama.org/main/introduction/basic_usage/). diff --git a/vcpkg.json b/vcpkg.json index 161a99d88..1520fb158 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -1,7 +1,7 @@ { "$schema": "https://raw.githubusercontent.com/microsoft/vcpkg-tool/main/docs/vcpkg.schema.json", "name": "arcade-learning-environment", - "version": "0.9.1", + "version": "0.10.0", "dependencies": [ "zlib" ], diff --git a/version.txt b/version.txt index f374f6662..78bc1abd1 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.9.1 +0.10.0 From edfa24f522f4d9639b38b371f55343a5d2e40727 Mon Sep 17 00:00:00 2001 From: Jet Date: Sat, 10 Aug 2024 13:27:34 +0900 Subject: [PATCH 35/39] 0.11 --- .github/workflows/ci.yml | 2 +- CHANGELOG.md | 2 +- vcpkg.json | 2 +- version.txt | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 59aef3911..a1a473edb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -289,7 +289,7 @@ jobs: - name: Build # wildcarding doesn't work for some reason, therefore, update the project version here - run: python -m pip install wheels/ale_py-0.10.0-${{ matrix.wheel-name }}.whl + run: python -m pip install wheels/ale_py-0.11.0-${{ matrix.wheel-name }}.whl - name: Install Gymnasium and pytest run: python -m pip install gymnasium>=1.0.0a2 pytest diff --git a/CHANGELOG.md b/CHANGELOG.md index 18a646d04..2ada33a13 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## 0.10.0 - +## 0.11.0 - diff --git a/vcpkg.json b/vcpkg.json index 1520fb158..bd116e9d3 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -1,7 +1,7 @@ { "$schema": "https://raw.githubusercontent.com/microsoft/vcpkg-tool/main/docs/vcpkg.schema.json", "name": "arcade-learning-environment", - "version": "0.10.0", + "version": "0.11.0", "dependencies": [ "zlib" ], diff --git a/version.txt b/version.txt index 78bc1abd1..d9df1bbc0 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.10.0 +0.11.0 From 479d34246545bf4892e5a28d380fb10fdcd00c02 Mon Sep 17 00:00:00 2001 From: Jet Date: Sat, 10 Aug 2024 13:33:39 +0900 Subject: [PATCH 36/39] add changelog --- CHANGELOG.md | 64 +++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ada33a13..1681e6e8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,69 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## 0.11.0 - - +Previously in the original ALE interface, the actions are only joystick ActionEnum inputs. +Then, for games that use a paddle instead of a joystick, joystick controls are mapped into discrete actions applied to paddles, ie: +- All left actions (`LEFTDOWN`, `LEFTUP`, `LEFT...`) -> paddle left max +- All right actions (`RIGHTDOWN`, `RIGHTUP`, `RIGHT...`) -> paddle right max +- Up... etc. +- Down... etc. + +This results in loss of continuous action for paddles. +This change keeps this functionality and interface, but allows for continuous action inputs for games that allow paddle usage. + +To do that, the CPP interface has been modified. + +_Old Discrete ALE interface_ +```cpp +reward_t ALEInterface::act(Action action) +``` + +_New Mixed Discrete-Continuous ALE interface_ +```cpp +reward_t ALEInterface::act(Action action, float paddle_strength = 1.0) +``` + +For games that utilize paddles, if the paddle strength parameter is set (the default value is 1.0), we pass the paddle action to the underlying game via [this change](https://github.com/Farama-Foundation/Arcade-Learning-Environment/pull/550/files#diff-6d221bfa0361147785924bb8dbd7176abb4727e0d2720cfdda63b5bd6c8fbdefR207): +```cpp +delta_a = static_cast(-PADDLE_DELTA * fabs(paddle_a_strength)); +``` + +This maintains backwards compatibility (it performs exactly the same if `paddle_x_strength` is not applied). +For games where the paddle is not used, the `paddle_x_strength` parameter is just ignored. +This mirrors the real world scenario where you have a paddle connected, but the game doesn't react to it when the paddle is turned. +The Python interface has also been updated. + +_Old Discrete ALE Python Interface_ +```py +ale.act(action: int) +``` + +_New Mixed Discrete-Continuous ALE Python Interface_ +```py +ale.act(action: int, strength: float = 1.0) +``` + +The main change this PR applies over the original CALE implementation is that the discretization is now handled at the Python level. +More specifically, when continuous action space is used. +```py +if continuous: + # action is expected to be a [2,] array of floats + x, y = action[0] * np.cos(action[1]), action[0] * np.sin(action[1]) + action_idx = self.map_action_idx( + left_center_right=( + -int(x < self.continuous_action_threshold) + + int(x > self.continuous_action_threshold) + ), + down_center_up=( + -int(y < self.continuous_action_threshold) + + int(y > self.continuous_action_threshold) + ), + fire=(action[-1] > self.continuous_action_threshold), + ) + ale.act(action_idx, action[1]) +``` + +More specifically, [`self.map_action_idx`](https://github.com/Farama-Foundation/Arcade-Learning-Environment/pull/550/files#diff-057906329e72d689f1d4d9d9e3f80df11ffe74da581b29b3838a436e90841b5cR388-R447) is an `lru_cache`-ed function that takes the continuous action direction and maps it into an ActionEnum. ## [0.9.0] - 2024-05-10 From 8942b5f29b341075fc27f0a1f2400e1b6a9e5c91 Mon Sep 17 00:00:00 2001 From: Jet Date: Sat, 10 Aug 2024 13:36:18 +0900 Subject: [PATCH 37/39] less verbose changelog --- CHANGELOG.md | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1681e6e8a..a84c21864 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,14 +29,10 @@ _New Mixed Discrete-Continuous ALE interface_ reward_t ALEInterface::act(Action action, float paddle_strength = 1.0) ``` -For games that utilize paddles, if the paddle strength parameter is set (the default value is 1.0), we pass the paddle action to the underlying game via [this change](https://github.com/Farama-Foundation/Arcade-Learning-Environment/pull/550/files#diff-6d221bfa0361147785924bb8dbd7176abb4727e0d2720cfdda63b5bd6c8fbdefR207): -```cpp -delta_a = static_cast(-PADDLE_DELTA * fabs(paddle_a_strength)); -``` - -This maintains backwards compatibility (it performs exactly the same if `paddle_x_strength` is not applied). -For games where the paddle is not used, the `paddle_x_strength` parameter is just ignored. +Games where the paddle is not used simply have the `paddle_strength` parameter ignored. This mirrors the real world scenario where you have a paddle connected, but the game doesn't react to it when the paddle is turned. +This maintains backwards compatibility. + The Python interface has also been updated. _Old Discrete ALE Python Interface_ @@ -49,8 +45,7 @@ _New Mixed Discrete-Continuous ALE Python Interface_ ale.act(action: int, strength: float = 1.0) ``` -The main change this PR applies over the original CALE implementation is that the discretization is now handled at the Python level. -More specifically, when continuous action space is used. +More specifically, when continuous action space is used within an ALE gymnasium environment, discretization happens at the Python level. ```py if continuous: # action is expected to be a [2,] array of floats From 9a8c8dbef0e8fcde9e8e0719611fe9928da79f31 Mon Sep 17 00:00:00 2001 From: Jet Date: Mon, 12 Aug 2024 16:30:22 +0900 Subject: [PATCH 38/39] update change log and use 0.10.0 --- CHANGELOG.md | 6 +++++- vcpkg.json | 2 +- version.txt | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a84c21864..9447ea4ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## 0.11.0 - +## 0.10.0 - Previously in the original ALE interface, the actions are only joystick ActionEnum inputs. Then, for games that use a paddle instead of a joystick, joystick controls are mapped into discrete actions applied to paddles, ie: @@ -66,6 +66,10 @@ if continuous: More specifically, [`self.map_action_idx`](https://github.com/Farama-Foundation/Arcade-Learning-Environment/pull/550/files#diff-057906329e72d689f1d4d9d9e3f80df11ffe74da581b29b3838a436e90841b5cR388-R447) is an `lru_cache`-ed function that takes the continuous action direction and maps it into an ActionEnum. +## 0.9.1 - + +Added support for Numpy 2.0. + ## [0.9.0] - 2024-05-10 Previously, ALE implemented only a [Gym](https://github.com/openai/gym) based environment, however, as Gym is no longer maintained (last commit was 18 months ago). We have updated `ale-py` to use [Gymnasium](http://github.com/farama-Foundation/gymnasium) (a maintained fork of Gym) as the sole backend environment implementation. For more information on Gymnasium’s API, see their [introduction page](https://gymnasium.farama.org/main/introduction/basic_usage/). diff --git a/vcpkg.json b/vcpkg.json index bd116e9d3..1520fb158 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -1,7 +1,7 @@ { "$schema": "https://raw.githubusercontent.com/microsoft/vcpkg-tool/main/docs/vcpkg.schema.json", "name": "arcade-learning-environment", - "version": "0.11.0", + "version": "0.10.0", "dependencies": [ "zlib" ], diff --git a/version.txt b/version.txt index d9df1bbc0..78bc1abd1 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.11.0 +0.10.0 From 9cc9f72cacdb6225d3bbcded167bf3ca4f1c8582 Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Mon, 12 Aug 2024 09:43:44 +0100 Subject: [PATCH 39/39] Update ci.yml --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a1a473edb..59aef3911 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -289,7 +289,7 @@ jobs: - name: Build # wildcarding doesn't work for some reason, therefore, update the project version here - run: python -m pip install wheels/ale_py-0.11.0-${{ matrix.wheel-name }}.whl + run: python -m pip install wheels/ale_py-0.10.0-${{ matrix.wheel-name }}.whl - name: Install Gymnasium and pytest run: python -m pip install gymnasium>=1.0.0a2 pytest