Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more pre-commit hooks #510

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 58 additions & 9 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,61 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- id: debug-statements
- repo: https://github.com/psf/black
rev: 22.3.0
- id: check-symlinks
- id: destroyed-symlinks
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-toml
- id: check-ast
- id: check-added-large-files
- id: check-merge-conflict
- id: check-executables-have-shebangs
- id: check-shebang-scripts-are-executable
- id: detect-private-key
- id: debug-statements
# - repo: https://github.com/codespell-project/codespell
# rev: v2.2.6
# exclude: ^(src/common)|(src/emucore)|(src/environment)|(src/games)
# hooks:
# - id: codespell
# args:
# - --ignore-words-list=
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
hooks:
- id: black
- id: flake8
args:
# - '--per-file-ignores='
- --ignore=E203,W503,E741
- --max-complexity=30
- --max-line-length=456
- --show-source
- --statistics
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
hooks:
- id: pyupgrade
args: ["--py38-plus"]
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black"]
- repo: https://github.com/python/black
rev: 23.12.1
hooks:
- id: black
# - repo: https://github.com/pycqa/pydocstyle
# rev: 6.3.0
# hooks:
# - id: pydocstyle
## exclude: ^
# args:
# - --source
# - --explain
# - --convention=google
# additional_dependencies: ["tomli"]
3 changes: 2 additions & 1 deletion examples/python-interface/python_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
# ALE provided in examples/sharedLibraryInterfaceExample.cpp
import sys
from random import randrange
from ale_py import ALEInterface, SDL_SUPPORT

from ale_py import SDL_SUPPORT, ALEInterface

if len(sys.argv) < 2:
print(f"Usage: {sys.argv[0]} rom_file")
Expand Down
4 changes: 2 additions & 2 deletions examples/python-interface/python_example_with_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
# ALE provided in doc/examples/sharedLibraryInterfaceWithModesExample.cpp
import sys
from random import randrange
from ale_py import ALEInterface, SDL_SUPPORT

from ale_py import SDL_SUPPORT, ALEInterface

if len(sys.argv) < 2:
print(f"Usage: {sys.argv[0]} rom_file")
Expand Down Expand Up @@ -40,7 +41,6 @@
# Play one episode in each mode and in each difficulty
for mode in avail_modes:
for diff in avail_diff:

ale.setDifficulty(diff)
ale.setMode(mode)
ale.reset_game()
Expand Down
2 changes: 1 addition & 1 deletion examples/python-rom-package/roms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
import pathlib
import sys

if sys.version_info >= (3, 9):
import importlib.resources as resources
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def parse_version(version_file):
semver_regex = r"(?P<major>0|[1-9]\d*)\.(?P<minor>0|[1-9]\d*)\.(?P<patch>0|[1-9]\d*)(?:-(?P<prerelease>(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+(?P<buildmetadata>[0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?"
semver_prog = re.compile(semver_regex)

with open(version_file, "r") as fp:
with open(version_file) as fp:
version = fp.read().strip()
assert semver_prog.match(version) is not None

Expand Down
8 changes: 7 additions & 1 deletion src/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@
__version__ = "unknown"

# Import native shared library
from ale_py._ale_py import SDL_SUPPORT, Action, ALEInterface, ALEState, LoggerMode
from ale_py._ale_py import ( # noqa: E402
SDL_SUPPORT,
Action,
ALEInterface,
ALEState,
LoggerMode,
)

__all__ = ["Action", "ALEInterface", "ALEState", "LoggerMode", "SDL_SUPPORT"]
39 changes: 19 additions & 20 deletions src/python/gym_env.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from __future__ import annotations

import sys
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Sequence

import ale_py
import ale_py.roms as roms
import ale_py.roms.utils as rom_utils
import numpy as np

import gym
from gym import error, spaces, utils, logger
import numpy as np
from gym import error, logger, spaces, utils

if sys.version_info < (3, 11):
from typing_extensions import NotRequired, TypedDict
Expand All @@ -36,14 +35,14 @@ class AtariEnv(gym.Env, utils.EzPickle):
def __init__(
self,
game: str = "pong",
mode: Optional[int] = None,
difficulty: Optional[int] = None,
mode: int | None = None,
difficulty: int | None = None,
obs_type: str = "rgb",
frameskip: Union[Tuple[int, int], int] = 4,
frameskip: tuple[int, int] | int = 4,
repeat_action_probability: float = 0.25,
full_action_space: bool = False,
max_num_frames_per_episode: Optional[int] = None,
render_mode: Optional[str] = None,
max_num_frames_per_episode: int | None = None,
render_mode: str | None = None,
) -> None:
"""
Initialize the ALE for Gym.
Expand All @@ -59,7 +58,7 @@ def __init__(
repeat_action_probability: int =>
Probability to repeat actions, see Machado et al., 2018
full_action_space: bool => Use full action space?
max_num_frames_per_episode: int => Max number of frame per epsiode.
max_num_frames_per_episode: int => Max number of frame per episode.
Once `max_num_frames_per_episode` is reached the episode is
truncated.
render_mode: str => One of { 'human', 'rgb_array' }.
Expand Down Expand Up @@ -102,11 +101,11 @@ def __init__(
)
elif isinstance(frameskip, tuple) and frameskip[0] > frameskip[1]:
raise error.Error(
f"Invalid stochastic frameskip, lower bound is greater than upper bound."
"Invalid stochastic frameskip, lower bound is greater than upper bound."
)
elif isinstance(frameskip, tuple) and frameskip[0] <= 0:
raise error.Error(
f"Invalid stochastic frameskip lower bound is greater than upper bound."
"Invalid stochastic frameskip lower bound is greater than upper bound."
)

if render_mode is not None and render_mode not in {"rgb_array", "human"}:
Expand Down Expand Up @@ -181,7 +180,7 @@ def __init__(
else:
raise error.Error(f"Unrecognized observation type: {self._obs_type}")

def seed(self, seed: Optional[int] = None) -> Tuple[int, int]:
def seed(self, seed: int | None = None) -> tuple[int, int]:
"""
Seeds both the internal numpy rng for stochastic frame skip
as well as the ALE RNG.
Expand Down Expand Up @@ -225,7 +224,7 @@ def seed(self, seed: Optional[int] = None) -> Tuple[int, int]:
def step(
self,
action_ind: int,
) -> Tuple[np.ndarray, float, bool, bool, AtariEnvStepMetadata]:
) -> tuple[np.ndarray, float, bool, bool, AtariEnvStepMetadata]:
"""
Perform one agent step, i.e., repeats `action` frameskip # of steps.

Expand Down Expand Up @@ -263,9 +262,9 @@ def step(
def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[Dict[str, Any]] = None,
) -> Tuple[np.ndarray, AtariEnvStepMetadata]:
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[np.ndarray, AtariEnvStepMetadata]:
"""
Resets environment and returns initial observation.
"""
Expand Down Expand Up @@ -307,7 +306,7 @@ def render(self) -> Any:

def _get_obs(self) -> np.ndarray:
"""
Retreives the current observation.
Retrieves the current observation.
This is dependent on `self._obs_type`.
"""
if self._obs_type == "ram":
Expand All @@ -326,7 +325,7 @@ def _get_info(self) -> AtariEnvStepMetadata:
"frame_number": self.ale.getFrameNumber(),
}

def get_keys_to_action(self) -> Dict[Tuple[int], ale_py.Action]:
def get_keys_to_action(self) -> dict[tuple[int], ale_py.Action]:
"""
Return keymapping -> actions for human play.
"""
Expand Down Expand Up @@ -369,7 +368,7 @@ def get_keys_to_action(self) -> Dict[Tuple[int], ale_py.Action]:
)
)

def get_action_meanings(self) -> List[str]:
def get_action_meanings(self) -> list[str]:
"""
Return the meaning of each integer action.
"""
Expand Down
7 changes: 3 additions & 4 deletions src/python/gym_registration.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
from __future__ import annotations

from collections import defaultdict
from typing import Any, Callable, Mapping, NamedTuple, Sequence, Text, Union
from typing import Any, Callable, Mapping, NamedTuple, Sequence

import ale_py.roms as roms
from ale_py.roms import utils as rom_utils

from gym.envs.registration import register


class GymFlavour(NamedTuple):
suffix: str
kwargs: Union[Mapping[Text, Any], Callable[[str], Mapping[Text, Any]]]
kwargs: Mapping[str, Any] | Callable[[str], Mapping[str, Any]]


class GymConfig(NamedTuple):
version: str
kwargs: Mapping[Text, Any]
kwargs: Mapping[str, Any]
flavours: Sequence[GymFlavour]


Expand Down
28 changes: 14 additions & 14 deletions src/python/gymnasium_env.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import sys
from typing import Any, Literal, Optional, Sequence, Union
from typing import Any, Literal, Sequence

import ale_py
import gymnasium
Expand Down Expand Up @@ -36,14 +36,14 @@ class AtariEnv(gymnasium.Env, utils.EzPickle):
def __init__(
self,
game: str = "pong",
mode: Optional[int] = None,
difficulty: Optional[int] = None,
mode: int | None = None,
difficulty: int | None = None,
obs_type: Literal["rgb", "grayscale", "ram"] = "rgb",
frameskip: Union[tuple[int, int], int] = 4,
frameskip: tuple[int, int] | int = 4,
repeat_action_probability: float = 0.25,
full_action_space: bool = False,
max_num_frames_per_episode: Optional[int] = None,
render_mode: Optional[Literal["human", "rgb_array"]] = None,
max_num_frames_per_episode: int | None = None,
render_mode: Literal["human", "rgb_array"] | None = None,
) -> None:
"""
Initialize the ALE for Gymnasium.
Expand All @@ -59,7 +59,7 @@ def __init__(
repeat_action_probability: int =>
Probability to repeat actions, see Machado et al., 2018
full_action_space: bool => Use full action space?
max_num_frames_per_episode: int => Max number of frame per epsiode.
max_num_frames_per_episode: int => Max number of frame per episode.
Once `max_num_frames_per_episode` is reached the episode is
truncated.
render_mode: str => One of { 'human', 'rgb_array' }.
Expand Down Expand Up @@ -97,11 +97,11 @@ def __init__(
)
elif isinstance(frameskip, tuple) and frameskip[0] > frameskip[1]:
raise error.Error(
f"Invalid stochastic frameskip, lower bound is greater than upper bound."
"Invalid stochastic frameskip, lower bound is greater than upper bound."
)
elif isinstance(frameskip, tuple) and frameskip[0] <= 0:
raise error.Error(
f"Invalid stochastic frameskip lower bound is greater than upper bound."
"Invalid stochastic frameskip lower bound is greater than upper bound."
)

if render_mode is not None and render_mode not in {"rgb_array", "human"}:
Expand Down Expand Up @@ -177,7 +177,7 @@ def __init__(
else:
raise error.Error(f"Unrecognized observation type: {self._obs_type}")

def seed_game(self, seed: Optional[int] = None) -> tuple[int, int]:
def seed_game(self, seed: int | None = None) -> tuple[int, int]:
"""Seeds the internal and ALE RNG."""
ss = np.random.SeedSequence(seed)
np_seed, ale_seed = ss.generate_state(n_words=2)
Expand Down Expand Up @@ -232,8 +232,8 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride]
def reset( # pyright: ignore[reportIncompatibleMethodOverride]
self,
*,
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[np.ndarray, AtariEnvStepMetadata]:
"""Resets environment and returns initial observation."""
# sets the seeds if it's specified for both ALE and frameskip np
Expand All @@ -252,7 +252,7 @@ def reset( # pyright: ignore[reportIncompatibleMethodOverride]

return obs, info

def render(self) -> Optional[np.ndarray]:
def render(self) -> np.ndarray | None:
"""
Render is not supported by ALE. We use a paradigm similar to
Gym3 which allows you to specify `render_mode` during construction.
Expand All @@ -274,7 +274,7 @@ def render(self) -> Optional[np.ndarray]:

def _get_obs(self) -> np.ndarray:
"""
Retreives the current observation.
Retrieves the current observation.
This is dependent on `self._obs_type`.
"""
if self._obs_type == "ram":
Expand Down
9 changes: 4 additions & 5 deletions src/python/gymnasium_registration.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
from __future__ import annotations

from collections import defaultdict
from typing import Any, Callable, Mapping, NamedTuple, Sequence, Text, Union
from typing import Any, Callable, Mapping, NamedTuple, Sequence

import ale_py.roms as roms
from ale_py.roms import utils as rom_utils

import gymnasium
from ale_py.roms import utils as rom_utils


class GymFlavour(NamedTuple):
suffix: str
kwargs: Union[Mapping[Text, Any], Callable[[str], Mapping[Text, Any]]]
kwargs: Mapping[str, Any] | Callable[[str], Mapping[str, Any]]


class GymConfig(NamedTuple):
version: str
kwargs: Mapping[Text, Any]
kwargs: Mapping[str, Any]
flavours: Sequence[GymFlavour]


Expand Down
Loading
Loading