Skip to content

Commit

Permalink
move discretization to Python
Browse files Browse the repository at this point in the history
  • Loading branch information
jjshoots committed Aug 3, 2024
1 parent 9aef3d7 commit 3a971e8
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 28 deletions.
4 changes: 3 additions & 1 deletion src/environment/ale_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ void ALEState::applyActionPaddles(Event* event,
case PLAYER_A_DOWNLEFTFIRE:
delta_a = static_cast<int>(PADDLE_DELTA * fabs(paddle_a_strength));
break;

default:
break;
}
Expand All @@ -237,6 +238,7 @@ void ALEState::applyActionPaddles(Event* event,
case PLAYER_B_DOWNLEFTFIRE:
delta_b = static_cast<int>(PADDLE_DELTA * fabs(paddle_b_strength));
break;

default:
break;
}
Expand Down Expand Up @@ -300,7 +302,7 @@ void ALEState::setDifficultySwitches(Event* event, unsigned int value) {
event->set(Event::ConsoleRightDifficultyB, !((value & 2) >> 1));
}

void ALEState::setActionJoysticks(Event* event,
void ALEState::applyActionJoysticks(Event* event,
int player_a_action, int player_b_action) {
// Reset keys
resetKeys(event);
Expand Down
124 changes: 97 additions & 27 deletions src/python/env.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from functools import lru_cache
import sys
from typing import Any, Literal
from typing import Any, Literal, Sequence

import ale_py
import gymnasium
Expand Down Expand Up @@ -142,8 +143,6 @@ def __init__(
self.ale.setLoggerMode(ale_py.LoggerMode.Error)
# Config sticky action prob.
self.ale.setFloat("repeat_action_probability", repeat_action_probability)
# Config continuous action threshold (if using continuous actions).
self.ale.setFloat("continuous_action_threshold", continuous_action_threshold)

if max_num_frames_per_episode is not None:
self.ale.setInt("max_num_frames_per_episode", max_num_frames_per_episode)
Expand All @@ -157,23 +156,24 @@ def __init__(
self.seed_game()
self.load_game()

# get the set of legal actions
self._action_set = (
self.ale.getLegalActionSet()
if full_action_space
else self.ale.getMinimalActionSet()
)

# action space
self.continuous = continuous
self.continuous_action_threshold = continuous_action_threshold
if continuous:
# We don't need action_set for continuous actions.
self._action_set = None
# Actions are radius, theta, and fire, where first two are the
# parameters of polar coordinates.
self.action_space = spaces.Box(
np.array([0, -np.pi, 0]).astype(np.float32),
np.array([+1, +np.pi, +1]).astype(np.float32),
np.array([0.0, -np.pi, 0.0]).astype(np.float32),
np.array([1.0, np.pi, 1.0]).astype(np.float32),
) # radius, theta, fire. First two are polar coordinates.
else:
self._action_set = (
self.ale.getLegalActionSet()
if full_action_space
else self.ale.getMinimalActionSet()
)
self.action_space = spaces.Discrete(len(self._action_set))

# initialize observation space
Expand Down Expand Up @@ -239,14 +239,15 @@ def reset( # pyright: ignore[reportIncompatibleMethodOverride]

def step( # pyright: ignore[reportIncompatibleMethodOverride]
self,
action: int | np.ndarray,
action: int | Sequence[float],
) -> tuple[np.ndarray, float, bool, bool, AtariEnvStepMetadata]:
"""
Perform one agent step, i.e., repeats `action` frameskip # of steps.
Args:
action_ind: int | np.ndarray => Action index to execute, or numpy
array of floats if continuous.
action: int | Sequence[float] =>
if `continuous=False` -> action index to execute
if `continuous=True` -> numpy array of r, theta, fire
Returns:
tuple[np.ndarray, float, bool, bool, Dict[str, Any]] =>
Expand All @@ -268,13 +269,23 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride]
reward = 0.0
for _ in range(frameskip):
if self.continuous:
if len(action) != 3:
raise error.Error("Actions must have 3-dimensions.")

if isinstance(action, np.ndarray):
action = action.tolist()
r, theta, fire = action
reward += self.ale.actContinuous(r, theta, fire)
# compute the x, y, fire of the joystick
assert isinstance(action, Sequence)
strength = action[0]
x, y = action[0] * np.cos(action[1]), action[0] * np.sin(action[1])
action = self.map_action_idx(
left_center_right=(
-(x < self.continuous_action_threshold)
+(x > self.continuous_action_threshold)
),
down_up_center=(
-(y < self.continuous_action_threshold)
+(y > self.continuous_action_threshold)
),
fire=(action[-1] > self.continuous_action_threshold),
)

reward += self.ale.act(action, strength)
else:
reward += self.ale.act(self._action_set[action])
is_terminal = self.ale.game_over(with_truncation=False)
Expand Down Expand Up @@ -323,6 +334,7 @@ def _get_info(self) -> AtariEnvStepMetadata:
"frame_number": self.ale.getFrameNumber(),
}

@lru_cache(1)
def get_keys_to_action(self) -> dict[tuple[int, ...], ale_py.Action]:
"""
Return keymapping -> actions for human play.
Expand Down Expand Up @@ -358,12 +370,70 @@ def get_keys_to_action(self) -> dict[tuple[int, ...], ale_py.Action]:
# (key, key, ...) -> action_idx
# where action_idx is the integer value of the action enum
#
return dict(
zip(
map(lambda action: tuple(sorted(mapping[action])), self._action_set),
range(len(self._action_set)),
return {
tuple(sorted(mapping[act_idx])): act_idx
for act_idx in self.action_set
}

@lru_cache(18)
def map_action_idx(self, left_center_right: int, down_center_up: int, fire: bool) -> int:
"""
Return an action idx given unit actions for underlying env.
"""
# no op and fire
if left_center_right == 0 and down_center_up == 0 and not fire:
return ale_py.Action.NOOP
elif left_center_right == 0 and down_center_up == 0 and fire:
return ale_py.Action.FIRE

# cardinal no fire
elif left_center_right == -1 and down_center_up == 0 and not fire:
return ale_py.Action.LEFT
elif left_center_right == 1 and down_center_up == 0 and not fire:
return ale_py.Action.RIGHT
elif left_center_right == 0 and down_center_up == -1 and not fire:
return ale_py.Action.DOWN
elif left_center_right == 0 and down_center_up == 1 and not fire:
return ale_py.Action.UP

# cardinal fire
if left_center_right == -1 and down_center_up == 0 and fire:
return ale_py.Action.LEFTFIRE
elif left_center_right == 1 and down_center_up == 0 and fire:
return ale_py.Action.RIGHTFIRE
elif left_center_right == 0 and down_center_up == -1 and fire:
return ale_py.Action.DOWNFIRE
elif left_center_right == 0 and down_center_up == 1 and fire:
return ale_py.Action.UPFIRE

# diagonal no fire
elif left_center_right == -1 and down_center_up == -1 and not fire:
return ale_py.Action.DOWNLEFT
elif left_center_right == 1 and down_center_up == -1 and not fire:
return ale_py.Action.DOWNRIGHT
elif left_center_right == -1 and down_center_up == 1 and not fire:
return ale_py.Action.UPLEFT
elif left_center_right == 1 and down_center_up == 1 and not fire:
return ale_py.Action.UPRIGHT

# diagonal fire
elif left_center_right == -1 and down_center_up == -1 and fire:
return ale_py.Action.DOWNLEFTFIRE
elif left_center_right == 1 and down_center_up == -1 and fire:
return ale_py.Action.DOWNRIGHTFIRE
elif left_center_right == -1 and down_center_up == 1 and fire:
return ale_py.Action.UPLEFTFIRE
elif left_center_right == 1 and down_center_up == 1 and fire:
return ale_py.Action.UPRIGHTFIRE

# just in case
else:
raise LookupError(
"Did not expect to get here, "
"expected `left_center_right` and `down_center_up` to be in {-1, 0, 1} "
"and `fire` to only be `True` or `False`. "
f"Received {left_center_right=}, {down_center_up=} and {fire=}."
)
)

def get_action_meanings(self) -> list[str]:
"""
Expand Down

0 comments on commit 3a971e8

Please sign in to comment.