Skip to content

Commit

Permalink
move things back to correct order
Browse files Browse the repository at this point in the history
  • Loading branch information
jjshoots committed Feb 6, 2024
1 parent 9562e9e commit 5e2fa7c
Showing 1 changed file with 21 additions and 22 deletions.
43 changes: 21 additions & 22 deletions src/python/env/gymnasium.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,28 +207,6 @@ def load_game(self) -> None:
if self._game_difficulty is not None:
self.ale.setDifficulty(self._game_difficulty)


def reset( # pyright: ignore[reportIncompatibleMethodOverride]
self,
*,
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None,
) -> tuple[np.ndarray, AtariEnvStepMetadata]:
"""Resets environment and returns initial observation."""
# sets the seeds if it's specified for both ALE and frameskip np
# we only want to do this when commanded to so we don't reset all previous states, statistics, etc.
seeded_with = self.seed_game(seed) if seed else None
self.load_game()

self.ale.reset_game()

obs = self._get_obs()
info = self._get_info()
if seeded_with is not None:
info["seeds"] = seeded_with

return obs, info

def step( # pyright: ignore[reportIncompatibleMethodOverride]
self,
action: int,
Expand Down Expand Up @@ -264,6 +242,27 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride]

return self._get_obs(), reward, is_terminal, is_truncated, self._get_info()

def reset( # pyright: ignore[reportIncompatibleMethodOverride]
self,
*,
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None,
) -> tuple[np.ndarray, AtariEnvStepMetadata]:
"""Resets environment and returns initial observation."""
# sets the seeds if it's specified for both ALE and frameskip np
# we only want to do this when commanded to so we don't reset all previous states, statistics, etc.
seeded_with = self.seed_game(seed) if seed else None
self.load_game()

self.ale.reset_game()

obs = self._get_obs()
info = self._get_info()
if seeded_with is not None:
info["seeds"] = seeded_with

return obs, info

def render(self) -> Optional[np.ndarray]:
"""
Render is not supported by ALE. We use a paradigm similar to
Expand Down

0 comments on commit 5e2fa7c

Please sign in to comment.