Skip to content

Commit

Permalink
Add render_mode option to datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Jul 21, 2023
1 parent a7207cd commit 2121edd
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions d3rlpy/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
load_v1,
)
from .envs import ChannelFirst, FrameStack
from .logging import LOG

__all__ = [
"DATA_DIRECTORY",
Expand Down Expand Up @@ -51,6 +52,7 @@ def get_cartpole(
dataset_type: str = "replay",
transition_picker: Optional[TransitionPickerProtocol] = None,
trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[np.ndarray, int]]:
"""Returns cartpole dataset and environment.
Expand All @@ -62,6 +64,7 @@ def get_cartpole(
``['replay', 'random']``.
transition_picker: TransitionPickerProtocol object.
trajectory_slicer: TrajectorySlicerProtocol object.
render_mode: Mode of rendering (``human``, ``rgb_array``).
Returns:
tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.
Expand Down Expand Up @@ -94,7 +97,7 @@ def get_cartpole(
)

# environment
env = gym.make("CartPole-v1")
env = gym.make("CartPole-v1", render_mode=render_mode)

return dataset, env

Expand All @@ -103,6 +106,7 @@ def get_pendulum(
dataset_type: str = "replay",
transition_picker: Optional[TransitionPickerProtocol] = None,
trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[np.ndarray, np.ndarray]]:
"""Returns pendulum dataset and environment.
Expand All @@ -114,6 +118,7 @@ def get_pendulum(
``['replay', 'random']``.
transition_picker: TransitionPickerProtocol object.
trajectory_slicer: TrajectorySlicerProtocol object.
render_mode: Mode of rendering (``human``, ``rgb_array``).
Returns:
tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.
Expand Down Expand Up @@ -145,14 +150,15 @@ def get_pendulum(
)

# environment
env = gym.make("Pendulum-v1")
env = gym.make("Pendulum-v1", render_mode=render_mode)

return dataset, env


def get_atari(
env_name: str,
num_stack: Optional[int] = None,
render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[np.ndarray, int]]:
"""Returns atari dataset and envrironment.
Expand All @@ -171,14 +177,15 @@ def get_atari(
Args:
env_name: environment id of d4rl-atari dataset.
num_stack: the number of frames to stack (only applied to env).
render_mode: Mode of rendering (``human``, ``rgb_array``).
Returns:
tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.
"""
try:
import d4rl_atari # type: ignore

env = gym.make(env_name)
env = gym.make(env_name, render_mode=render_mode)
raw_dataset = env.get_dataset() # type: ignore
episode_generator = EpisodeGenerator(**raw_dataset)
dataset = create_infinite_replay_buffer(
Expand All @@ -203,6 +210,7 @@ def get_atari_transitions(
fraction: float = 0.01,
index: int = 0,
num_stack: Optional[int] = None,
render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[np.ndarray, int]]:
"""Returns atari dataset as a list of Transition objects and envrironment.
Expand All @@ -226,6 +234,7 @@ def get_atari_transitions(
fraction: fraction of sampled transitions.
index: index to specify which trial to load.
num_stack: the number of frames to stack (only applied to env).
render_mode: Mode of rendering (``human``, ``rgb_array``).
Returns:
tuple of a list of :class:`d3rlpy.dataset.Transition` and gym
Expand All @@ -239,8 +248,12 @@ def get_atari_transitions(

copied_episodes = []
for i in range(50):
env_name = f"{game_name}-epoch-{i + 1}-v{index}"
LOG.info(f"Collecting {env_name}...")
env = gym.make(
f"{game_name}-epoch-{i + 1}-v{index}", sticky_action=True
env_name,
sticky_action=True,
render_mode=render_mode,
)
raw_dataset = env.get_dataset() # type: ignore
episode_generator = EpisodeGenerator(**raw_dataset)
Expand Down Expand Up @@ -296,6 +309,7 @@ def get_d4rl(
env_name: str,
transition_picker: Optional[TransitionPickerProtocol] = None,
trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[np.ndarray, np.ndarray]]:
"""Returns d4rl dataset and envrironment.
Expand All @@ -316,6 +330,7 @@ def get_d4rl(
env_name: environment id of d4rl dataset.
transition_picker: TransitionPickerProtocol object.
trajectory_slicer: TrajectorySlicerProtocol object.
render_mode: Mode of rendering (``human``, ``rgb_array``).
Returns:
tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.
Expand Down Expand Up @@ -344,6 +359,7 @@ def get_d4rl(

# wrapped by NormalizedBoxEnv that is incompatible with newer Gym
unwrapped_env: gym.Env[Any, Any] = env.env.env.env.wrapped_env # type: ignore
unwrapped_env.render_mode = render_mode # overwrite

return dataset, TimeLimit(unwrapped_env, max_episode_steps=1000)
except ImportError as e:
Expand Down Expand Up @@ -423,6 +439,7 @@ def get_dataset(
env_name: str,
transition_picker: Optional[TransitionPickerProtocol] = None,
trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[Any, Any]]:
"""Returns dataset and envrironment by guessing from name.
Expand Down Expand Up @@ -456,6 +473,7 @@ def get_dataset(
env_name: environment id of the dataset.
transition_picker: TransitionPickerProtocol object.
trajectory_slicer: TrajectorySlicerProtocol object.
render_mode: Mode of rendering (``human``, ``rgb_array``).
Returns:
tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.
Expand All @@ -465,35 +483,41 @@ def get_dataset(
dataset_type="replay",
transition_picker=transition_picker,
trajectory_slicer=trajectory_slicer,
render_mode=render_mode,
)
elif env_name == "cartpole-random":
return get_cartpole(
dataset_type="random",
transition_picker=transition_picker,
trajectory_slicer=trajectory_slicer,
render_mode=render_mode,
)
elif env_name == "pendulum-replay":
return get_pendulum(
dataset_type="replay",
transition_picker=transition_picker,
trajectory_slicer=trajectory_slicer,
render_mode=render_mode,
)
elif env_name == "pendulum-random":
return get_pendulum(
dataset_type="random",
transition_picker=transition_picker,
trajectory_slicer=trajectory_slicer,
render_mode=render_mode,
)
elif re.match(r"^bullet-.+$", env_name):
return get_d4rl(
env_name,
transition_picker=transition_picker,
trajectory_slicer=trajectory_slicer,
render_mode=render_mode,
)
elif re.match(r"hopper|halfcheetah|walker|ant", env_name):
return get_d4rl(
env_name,
transition_picker=transition_picker,
trajectory_slicer=trajectory_slicer,
render_mode=render_mode,
)
raise ValueError(f"Unrecognized env_name: {env_name}.")

0 comments on commit 2121edd

Please sign in to comment.