From 2121edd853ea32a9fe6d262de639c69c13b50c24 Mon Sep 17 00:00:00 2001 From: takuseno Date: Fri, 21 Jul 2023 17:29:45 +0900 Subject: [PATCH] Add render_mode option to datasets --- d3rlpy/datasets.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/d3rlpy/datasets.py b/d3rlpy/datasets.py index ac8c73b5..ac3b84d3 100644 --- a/d3rlpy/datasets.py +++ b/d3rlpy/datasets.py @@ -23,6 +23,7 @@ load_v1, ) from .envs import ChannelFirst, FrameStack +from .logging import LOG __all__ = [ "DATA_DIRECTORY", @@ -51,6 +52,7 @@ def get_cartpole( dataset_type: str = "replay", transition_picker: Optional[TransitionPickerProtocol] = None, trajectory_slicer: Optional[TrajectorySlicerProtocol] = None, + render_mode: Optional[str] = None, ) -> Tuple[ReplayBuffer, gym.Env[np.ndarray, int]]: """Returns cartpole dataset and environment. @@ -62,6 +64,7 @@ def get_cartpole( ``['replay', 'random']``. transition_picker: TransitionPickerProtocol object. trajectory_slicer: TrajectorySlicerProtocol object. + render_mode: Mode of rendering (``human``, ``rgb_array``). Returns: tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment. @@ -94,7 +97,7 @@ def get_cartpole( ) # environment - env = gym.make("CartPole-v1") + env = gym.make("CartPole-v1", render_mode=render_mode) return dataset, env @@ -103,6 +106,7 @@ def get_pendulum( dataset_type: str = "replay", transition_picker: Optional[TransitionPickerProtocol] = None, trajectory_slicer: Optional[TrajectorySlicerProtocol] = None, + render_mode: Optional[str] = None, ) -> Tuple[ReplayBuffer, gym.Env[np.ndarray, np.ndarray]]: """Returns pendulum dataset and environment. @@ -114,6 +118,7 @@ def get_pendulum( ``['replay', 'random']``. transition_picker: TransitionPickerProtocol object. trajectory_slicer: TrajectorySlicerProtocol object. + render_mode: Mode of rendering (``human``, ``rgb_array``). Returns: tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment. @@ -145,7 +150,7 @@ def get_pendulum( ) # environment - env = gym.make("Pendulum-v1") + env = gym.make("Pendulum-v1", render_mode=render_mode) return dataset, env @@ -153,6 +158,7 @@ def get_pendulum( def get_atari( env_name: str, num_stack: Optional[int] = None, + render_mode: Optional[str] = None, ) -> Tuple[ReplayBuffer, gym.Env[np.ndarray, int]]: """Returns atari dataset and envrironment. @@ -171,6 +177,7 @@ def get_atari( Args: env_name: environment id of d4rl-atari dataset. num_stack: the number of frames to stack (only applied to env). + render_mode: Mode of rendering (``human``, ``rgb_array``). Returns: tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment. @@ -178,7 +185,7 @@ def get_atari( try: import d4rl_atari # type: ignore - env = gym.make(env_name) + env = gym.make(env_name, render_mode=render_mode) raw_dataset = env.get_dataset() # type: ignore episode_generator = EpisodeGenerator(**raw_dataset) dataset = create_infinite_replay_buffer( @@ -203,6 +210,7 @@ def get_atari_transitions( fraction: float = 0.01, index: int = 0, num_stack: Optional[int] = None, + render_mode: Optional[str] = None, ) -> Tuple[ReplayBuffer, gym.Env[np.ndarray, int]]: """Returns atari dataset as a list of Transition objects and envrironment. @@ -226,6 +234,7 @@ def get_atari_transitions( fraction: fraction of sampled transitions. index: index to specify which trial to load. num_stack: the number of frames to stack (only applied to env). + render_mode: Mode of rendering (``human``, ``rgb_array``). Returns: tuple of a list of :class:`d3rlpy.dataset.Transition` and gym @@ -239,8 +248,12 @@ def get_atari_transitions( copied_episodes = [] for i in range(50): + env_name = f"{game_name}-epoch-{i + 1}-v{index}" + LOG.info(f"Collecting {env_name}...") env = gym.make( - f"{game_name}-epoch-{i + 1}-v{index}", sticky_action=True + env_name, + sticky_action=True, + render_mode=render_mode, ) raw_dataset = env.get_dataset() # type: ignore episode_generator = EpisodeGenerator(**raw_dataset) @@ -296,6 +309,7 @@ def get_d4rl( env_name: str, transition_picker: Optional[TransitionPickerProtocol] = None, trajectory_slicer: Optional[TrajectorySlicerProtocol] = None, + render_mode: Optional[str] = None, ) -> Tuple[ReplayBuffer, gym.Env[np.ndarray, np.ndarray]]: """Returns d4rl dataset and envrironment. @@ -316,6 +330,7 @@ def get_d4rl( env_name: environment id of d4rl dataset. transition_picker: TransitionPickerProtocol object. trajectory_slicer: TrajectorySlicerProtocol object. + render_mode: Mode of rendering (``human``, ``rgb_array``). Returns: tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment. @@ -344,6 +359,7 @@ def get_d4rl( # wrapped by NormalizedBoxEnv that is incompatible with newer Gym unwrapped_env: gym.Env[Any, Any] = env.env.env.env.wrapped_env # type: ignore + unwrapped_env.render_mode = render_mode # overwrite return dataset, TimeLimit(unwrapped_env, max_episode_steps=1000) except ImportError as e: @@ -423,6 +439,7 @@ def get_dataset( env_name: str, transition_picker: Optional[TransitionPickerProtocol] = None, trajectory_slicer: Optional[TrajectorySlicerProtocol] = None, + render_mode: Optional[str] = None, ) -> Tuple[ReplayBuffer, gym.Env[Any, Any]]: """Returns dataset and envrironment by guessing from name. @@ -456,6 +473,7 @@ def get_dataset( env_name: environment id of the dataset. transition_picker: TransitionPickerProtocol object. trajectory_slicer: TrajectorySlicerProtocol object. + render_mode: Mode of rendering (``human``, ``rgb_array``). Returns: tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment. @@ -465,35 +483,41 @@ def get_dataset( dataset_type="replay", transition_picker=transition_picker, trajectory_slicer=trajectory_slicer, + render_mode=render_mode, ) elif env_name == "cartpole-random": return get_cartpole( dataset_type="random", transition_picker=transition_picker, trajectory_slicer=trajectory_slicer, + render_mode=render_mode, ) elif env_name == "pendulum-replay": return get_pendulum( dataset_type="replay", transition_picker=transition_picker, trajectory_slicer=trajectory_slicer, + render_mode=render_mode, ) elif env_name == "pendulum-random": return get_pendulum( dataset_type="random", transition_picker=transition_picker, trajectory_slicer=trajectory_slicer, + render_mode=render_mode, ) elif re.match(r"^bullet-.+$", env_name): return get_d4rl( env_name, transition_picker=transition_picker, trajectory_slicer=trajectory_slicer, + render_mode=render_mode, ) elif re.match(r"hopper|halfcheetah|walker|ant", env_name): return get_d4rl( env_name, transition_picker=transition_picker, trajectory_slicer=trajectory_slicer, + render_mode=render_mode, ) raise ValueError(f"Unrecognized env_name: {env_name}.")