Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for continuous actions in the ALE (CALE) #539

Merged
merged 15 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/ale_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,14 @@ reward_t ALEInterface::act(Action action) {
return environment->act(action, PLAYER_B_NOOP);
}

// Applies a continuous action to the game and returns the reward. It is the
// user's responsibility to check if the game has ended and reset
// when necessary - this method will keep pressing buttons on the
// game over screen.
reward_t ALEInterface::actContinuous(float r, float theta, float fire) {
return environment->actContinuous(r, theta, fire, 0.0, 0.0, 0.0);
}

// Returns the vector of modes available for the current game.
// This should be called only after the rom is loaded.
ModeVect ALEInterface::getAvailableModes() const {
Expand Down
6 changes: 6 additions & 0 deletions src/ale_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ class ALEInterface {
// game over screen.
reward_t act(Action action);

// Applies a continuous action to the game and returns the reward. It is the
// user's responsibility to check if the game has ended and reset
// when necessary - this method will keep pressing buttons on the
// game over screen.
reward_t actContinuous(float r, float theta, float fire);

// Indicates if the game has ended.
bool game_over(bool with_truncation = true) const;

Expand Down
1 change: 1 addition & 0 deletions src/emucore/Settings.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ void Settings::setDefaultSettings() {
boolSettings.insert(std::pair<std::string, bool>("send_rgb", false));
intSettings.insert(std::pair<std::string, int>("frame_skip", 1));
floatSettings.insert(std::pair<std::string, float>("repeat_action_probability", 0.25));
floatSettings.insert(std::pair<std::string, float>("continuous_action_threshold", 0.5));
stringSettings.insert(std::pair<std::string, std::string>("rom_file", ""));
// Whether to truncate an episode on loss of life.
boolSettings.insert(std::pair<std::string, bool>("truncate_on_loss_of_life", false));
Expand Down
90 changes: 90 additions & 0 deletions src/environment/ale_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "environment/ale_state.hpp"

#include <cassert>
#include <cmath>
#include <sstream>
#include <stdexcept>
#include <string>
Expand Down Expand Up @@ -287,6 +288,47 @@ void ALEState::applyActionPaddles(Event* event, int player_a_action,
}
}

void ALEState::applyActionPaddlesContinuous(
Event* event,
float player_a_r, float player_a_theta, float player_a_fire,
float player_b_r, float player_b_theta, float player_b_fire,
float continuous_action_threshold) {
// Reset keys
resetKeys(event);

// Convert polar coordinates to x/y position.
float a_x = player_a_r * cos(player_a_theta);
float a_y = player_a_r * sin(player_a_theta);
float b_x = player_b_r * cos(player_b_theta);
float b_y = player_b_r * sin(player_b_theta);

// First compute whether we should increase or decrease the paddle position
// (for both left and right players)
int delta_a = 0;
if (a_x > continuous_action_threshold) { // Right action.
delta_a = -PADDLE_DELTA;
} else if (a_x < -continuous_action_threshold) { // Left action.
delta_a = PADDLE_DELTA;
}
int delta_b = 0;
if (b_x > continuous_action_threshold) { // Right action.
delta_b = -PADDLE_DELTA;
} else if (b_x < -continuous_action_threshold) { // Left action.
delta_b = PADDLE_DELTA;
}

// Now update the paddle positions
updatePaddlePositions(event, delta_a, delta_b);

// Now add the fire event
if (player_a_fire > continuous_action_threshold) {
event->set(Event::PaddleZeroFire, 1);
}
if (player_b_fire > continuous_action_threshold) {
event->set(Event::PaddleOneFire, 1);
}
}

void ALEState::pressSelect(Event* event) {
resetKeys(event);
event->set(Event::ConsoleSelect, 1);
Expand Down Expand Up @@ -498,6 +540,54 @@ void ALEState::setActionJoysticks(Event* event, int player_a_action,
}
}


void ALEState::setActionJoysticksContinuous(
Event* event,
float player_a_r, float player_a_theta, float player_a_fire,
float player_b_r, float player_b_theta, float player_b_fire,
float continuous_action_threshold) {
// Reset keys
resetKeys(event);

// Convert polar coordinates to x/y position.
float a_x = player_a_r * cos(player_a_theta);
float a_y = player_a_r * sin(player_a_theta);
float b_x = player_b_r * cos(player_b_theta);
float b_y = player_b_r * sin(player_b_theta);

// Go through all possible events and add them if joystick position is there.
if (a_x < -continuous_action_threshold) {
event->set(Event::JoystickZeroLeft, 1);
}
if (a_x > continuous_action_threshold) {
event->set(Event::JoystickZeroRight, 1);
}
if (a_y < -continuous_action_threshold) {
event->set(Event::JoystickZeroDown, 1);
}
if (a_y > continuous_action_threshold) {
event->set(Event::JoystickZeroUp, 1);
}
if (player_a_fire > continuous_action_threshold) {
event->set(Event::JoystickZeroFire, 1);
}
if (b_x < -continuous_action_threshold) {
event->set(Event::JoystickOneLeft, 1);
}
if (b_x > continuous_action_threshold) {
event->set(Event::JoystickOneRight, 1);
}
if (b_y < -continuous_action_threshold) {
event->set(Event::JoystickOneDown, 1);
}
if (b_y > continuous_action_threshold) {
event->set(Event::JoystickOneUp, 1);
}
if (player_b_fire > continuous_action_threshold) {
event->set(Event::JoystickOneFire, 1);
}
}

/* ***************************************************************************
Function resetKeys
Unpresses all control-relevant keys
Expand Down
14 changes: 14 additions & 0 deletions src/environment/ale_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,23 @@ class ALEState {
* resistances. */
void applyActionPaddles(stella::Event* event_obj, int player_a_action,
int player_b_action);
/** Applies paddle continuous actions. This actually modifies the game state
* by updating the paddle resistances. */
void applyActionPaddlesContinuous(
stella::Event* event_obj,
float player_a_r, float player_a_theta, float player_a_fire,
float player_b_r, float player_b_theta, float player_b_fire,
float continuous_action_treshold = 0.5);
/** Sets the joystick events. No effect until the emulator is run forward. */
void setActionJoysticks(stella::Event* event_obj, int player_a_action,
int player_b_action);
/** Sets the joystick events for continuous actions. No effect until the
* emulator is run forward. */
void setActionJoysticksContinuous(
stella::Event* event_obj,
float player_a_r, float player_a_theta, float player_a_fire,
float player_b_r, float player_b_theta, float player_b_fire,
float continuous_action_threshold = 0.5);

void incrementFrame(int steps = 1);

Expand Down
106 changes: 106 additions & 0 deletions src/environment/stella_environment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ StellaEnvironment::StellaEnvironment(OSystem* osystem, RomSettings* settings)

m_repeat_action_probability =
m_osystem->settings().getFloat("repeat_action_probability");
m_continuous_action_threshold =
m_osystem->settings().getFloat("continuous_action_threshold");

m_frame_skip = m_osystem->settings().getInt("frame_skip");
if (m_frame_skip < 1) {
Expand Down Expand Up @@ -187,6 +189,49 @@ reward_t StellaEnvironment::act(Action player_a_action,
return std::clamp(sum_rewards, m_reward_min, m_reward_max);
}

reward_t StellaEnvironment::actContinuous(
float player_a_r, float player_a_theta, float player_a_fire,
float player_b_r, float player_b_theta, float player_b_fire) {
// Total reward received as we repeat the action
reward_t sum_rewards = 0;

Random& rng = getEnvironmentRNG();

// Apply the same action for a given number of times... note that act() will refuse to emulate
// past the terminal state
for (size_t i = 0; i < m_frame_skip; i++) {
// Stochastically drop actions, according to m_repeat_action_probability
if (rng.nextDouble() >= m_repeat_action_probability) {
m_player_a_r = player_a_r;
m_player_a_theta = player_a_theta;
m_player_a_fire = player_a_fire;
}
// @todo Possibly optimize by avoiding call to rand() when player B is "off" ?
if (rng.nextDouble() >= m_repeat_action_probability) {
m_player_b_r = player_b_r;
m_player_b_theta = player_b_theta;
m_player_b_fire = player_b_fire;
}

// If so desired, request one frame's worth of sound (this does nothing if recording
// is not enabled)
m_osystem->sound().recordNextFrame();

// Render screen if we're displaying it
m_osystem->screen().render();

// Similarly record screen as needed
if (m_screen_exporter.get() != NULL)
m_screen_exporter->saveNext(m_screen);

// Use the stored actions, which may or may not have changed this frame
sum_rewards += oneStepActContinuous(m_player_a_r, m_player_a_theta, m_player_a_fire,
m_player_b_r, m_player_b_theta, m_player_b_fire);
}

return sum_rewards;
}

/** This functions emulates a push on the reset button of the console */
void StellaEnvironment::softReset() {
emulate(RESET, PLAYER_B_NOOP, m_num_reset_steps);
Expand Down Expand Up @@ -217,6 +262,29 @@ reward_t StellaEnvironment::oneStepAct(Action player_a_action,
return m_settings->getReward();
}

/** Applies the given continuous actions (e.g. updating paddle positions when
* the paddle is used) and performs one simulation step in Stella. */
reward_t StellaEnvironment::oneStepActContinuous(
float player_a_r, float player_a_theta, float player_a_fire,
float player_b_r, float player_b_theta, float player_b_fire) {
// Once in a terminal state, refuse to go any further (special actions must be handled
// outside of this environment; in particular reset() should be called rather than passing
// RESET or SYSTEM_RESET.
if (isTerminal())
return 0;

// Convert illegal actions into NOOPs; actions such as reset are always legal
//noopIllegalActions(player_a_action, player_b_action);

// Emulate in the emulator
emulateContinuous(player_a_r, player_a_theta, player_a_fire,
player_b_r, player_b_theta, player_b_fire);
// Increment the number of frames seen so far
m_state.incrementFrame();

return m_settings->getReward();
}

bool StellaEnvironment::isTerminal() const {
return isGameTerminal() || isGameTruncated();
}
Expand Down Expand Up @@ -287,6 +355,44 @@ void StellaEnvironment::emulate(Action player_a_action, Action player_b_action,
processRAM();
}

void StellaEnvironment::emulateContinuous(
float player_a_r, float player_a_theta, float player_a_fire,
float player_b_r, float player_b_theta, float player_b_fire,
size_t num_steps) {
Event* event = m_osystem->event();

// Handle paddles separately: we have to manually update the paddle positions at each step
if (m_use_paddles) {
// Run emulator forward for 'num_steps'
for (size_t t = 0; t < num_steps; t++) {
// Update paddle position at every step
m_state.applyActionPaddlesContinuous(
event,
player_a_r, player_a_theta, player_a_fire,
player_b_r, player_b_theta, player_b_fire,
m_continuous_action_threshold);

m_osystem->console().mediaSource().update();
m_settings->step(m_osystem->console().system());
}
} else {
// In joystick mode we only need to set the action events once
m_state.setActionJoysticksContinuous(
event, player_a_r, player_a_theta, player_a_fire,
player_b_r, player_b_theta, player_b_fire,
m_continuous_action_threshold);

for (size_t t = 0; t < num_steps; t++) {
m_osystem->console().mediaSource().update();
m_settings->step(m_osystem->console().system());
}
}

// Parse screen and RAM into their respective data structures
processScreen();
processRAM();
}

/** Accessor methods for the environment state. */
void StellaEnvironment::setState(const ALEState& state) { m_state = state; }

Expand Down
27 changes: 27 additions & 0 deletions src/environment/stella_environment.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ class StellaEnvironment {
*/
reward_t act(Action player_a_action, Action player_b_action);

/** Applies the given continuous actions (e.g. updating paddle positions when
* the paddle is used) and performs one simulation step in Stella. Returns the
* resultant reward. When frame skip is set to > 1, up the corresponding
* number of simulation steps are performed. Note that the post-act() frame
* number might not correspond to the pre-act() frame number plus the frame
* skip.
*/
reward_t actContinuous(
float player_a_r, float player_a_theta, float player_a_fire,
float player_b_r, float player_b_theta, float player_b_fire);

/** This functions emulates a push on the reset button of the console */
void softReset();

Expand Down Expand Up @@ -121,9 +132,21 @@ class StellaEnvironment {
/** This applies an action exactly one time step. Helper function to act(). */
reward_t oneStepAct(Action player_a_action, Action player_b_action);

/** This applies a continuous action exactly one time step.
* Helper function to actContinuous().
*/
reward_t oneStepActContinuous(
float player_a_r, float player_a_theta, float player_a_fire,
float player_b_r, float player_b_theta, float player_b_fire);


/** Actually emulates the emulator for a given number of steps. */
void emulate(Action player_a_action, Action player_b_action,
size_t num_steps = 1);
void emulateContinuous(
float player_a_r, float player_a_theta, float player_a_fire,
float player_b_r, float player_b_theta, float player_b_fire,
size_t num_steps = 1);

/** Drops illegal actions, such as the fire button in skiing. Note that this is different
* from the minimal set of actions. */
Expand Down Expand Up @@ -153,6 +176,7 @@ class StellaEnvironment {
int m_max_num_frames_per_episode; // Maxmimum number of frames per episode
size_t m_frame_skip; // How many frames to emulate per act()
float m_repeat_action_probability; // Stochasticity of the environment
float m_continuous_action_threshold; // Continuous action threshold
std::unique_ptr<ScreenExporter> m_screen_exporter; // Automatic screen recorder
int m_max_lives; // Maximum number of lives at the start of an episode.
bool m_truncate_on_loss_of_life; // Whether to truncate episodes on loss of life.
Expand All @@ -161,6 +185,9 @@ class StellaEnvironment {

// The last actions taken by our players
Action m_player_a_action, m_player_b_action;
float m_player_a_r, m_player_b_r;
float m_player_a_theta, m_player_b_theta;
float m_player_a_fire, m_player_b_fire;
};

} // namespace ale
Expand Down
1 change: 1 addition & 0 deletions src/python/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class ALEInterface:
def __init__(self) -> None: ...
@overload
def act(self, action: Action) -> int: ...
def actContinuous(self, r: float, theta: float, fire: float) -> int: ...
@overload
def act(self, action: int) -> int: ...
def cloneState(self, *, include_rng: bool = False) -> ALEState: ...
Expand Down
1 change: 1 addition & 0 deletions src/python/ale_python_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ PYBIND11_MODULE(_ale_py, m) {
ale::ALEPythonInterface::act)
.def("act", (ale::reward_t(ale::ALEInterface::*)(ale::Action)) &
ale::ALEInterface::act)
.def("actContinuous", &ale::ALEPythonInterface::actContinuous)
.def("game_over", &ale::ALEPythonInterface::game_over, py::kw_only(), py::arg("with_truncation") = py::bool_(true))
.def("game_truncated", &ale::ALEPythonInterface::game_truncated)
.def("reset_game", &ale::ALEPythonInterface::reset_game)
Expand Down
Loading
Loading