diff --git a/d3rlpy/cli.py b/d3rlpy/cli.py index 3c54fa80..24483b66 100644 --- a/d3rlpy/cli.py +++ b/d3rlpy/cli.py @@ -350,28 +350,26 @@ def play( @cli.command(short_help="Install additional packages.") @click.argument("name") -def install(name: str) -> None: +def _install_module( + name: list[str], upgrade: bool = False, check: bool = True +) -> None: if name == "atari": - subprocess.run( - ["pip3", "install", "-U", "gym[atari,accept-rom-license]"], - check=True, - ) + _install_module(["gym[atari,accept-rom-license]"], upgrade=True) elif name == "d4rl_atari": - subprocess.run(["d3rlpy", "install", "atari"], check=True) - subprocess.run( - ["pip3", "install", "git+https://github.com/takuseno/d4rl-atari"], - check=True, - ) + install("atari") + _install_module(["git+https://github.com/takuseno/d4rl-atari"]) elif name == "d4rl": - subprocess.run( - [ - "pip3", - "install", - "git+https://github.com/Farama-Foundation/D4RL", - ], - check=True, - ) - subprocess.run(["pip3", "install", "-U", "gym"], check=True) - subprocess.run(["pip3", "uninstall", "-y", "pybullet"], check=True) + _install_module(["git+https://github.com/Farama-Foundation/D4RL"]) + _install_module(["gym"], upgrade=True) + _install_module(["-y", "pybullet"], upgrade=True) + elif name == "minari": + _install_module(["minari==0.4.2"], upgrade=True) else: raise ValueError(f"Unsupported command: {name}") + + +def _install_module( + name: list[str], upgrade: bool = False, check: bool = True +) -> None: + name = ["-U", *name] if upgrade else name + subprocess.run(["pip3", "install", *name], check=check) diff --git a/d3rlpy/datasets.py b/d3rlpy/datasets.py index 180dcde8..ce916371 100644 --- a/d3rlpy/datasets.py +++ b/d3rlpy/datasets.py @@ -7,6 +7,7 @@ from urllib import request import gym +import gymnasium import numpy as np from gym.wrappers.time_limit import TimeLimit @@ -442,6 +443,69 @@ def get_d4rl( ) from e +def get_minari( + env_name: str, + transition_picker: Optional[TransitionPickerProtocol] = None, + trajectory_slicer: Optional[TrajectorySlicerProtocol] = None, + render_mode: Optional[str] = None, +) -> Tuple[ReplayBuffer, gymnasium.Env[np.ndarray, np.ndarray]]: + """Returns minari dataset and envrironment. + + The dataset is provided through minari. + .. code-block:: python + from d3rlpy.datasets import get_minari + dataset, env = get_minari('door-cloned-v1') + Args: + env_name: environment id of minari dataset. + transition_picker: TransitionPickerProtocol object. + trajectory_slicer: TrajectorySlicerProtocol object. + render_mode: Mode of rendering (``human``, ``rgb_array``). + Returns: + tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment. + """ + try: + import minari + + _dataset: minari.MinariDataset = minari.load_dataset( + env_name, download=True + ) + + data = { + "observations": [], + "actions": [], + "rewards": [], + "terminations": [], + "truncations": [], + } + + for ep in _dataset: + for key in data.keys(): + data[key].append(getattr(ep, key)) + + dataset = MDPDataset( + observations=np.concatenate(data["observations"]), + actions=np.concatenate(data["actions"]), + rewards=np.concatenate(data["rewards"]), + terminals=np.concatenate(data["terminations"]), + timeouts=np.concatenate(data["truncations"]), + transition_picker=transition_picker, + trajectory_slicer=trajectory_slicer, + ) + + env = _dataset.recover_environment() + unwrapped_env = env.unwrapped + + unwrapped_env.render_mode = render_mode + return dataset, TimeLimit( + unwrapped_env, max_episode_steps=env.spec.max_episode_steps + ) + + except ImportError as e: + raise ImportError( + "minari is not installed.\n" "$ d3rlpy install minari" + ) from e + + ATARI_GAMES = [ "adventure", "air-raid", diff --git a/docs/references/datasets.rst b/docs/references/datasets.rst index 2b48a96f..d7e01c13 100644 --- a/docs/references/datasets.rst +++ b/docs/references/datasets.rst @@ -16,3 +16,4 @@ learning algorithms. d3rlpy.datasets.get_atari_transitions d3rlpy.datasets.get_d4rl d3rlpy.datasets.get_dataset + d3rlpy.datasets.get_minari diff --git a/tests/test_datasets.py b/tests/test_datasets.py index c86ccb1a..39e3e76e 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,6 +1,6 @@ import pytest -from d3rlpy.datasets import get_cartpole, get_dataset, get_pendulum +from d3rlpy.datasets import get_cartpole, get_dataset, get_minari, get_pendulum @pytest.mark.parametrize("dataset_type", ["replay", "random"]) @@ -23,3 +23,15 @@ def test_get_dataset(env_name: str) -> None: assert env.unwrapped.spec.id == "CartPole-v1" elif env_name == "pendulum-random": assert env.unwrapped.spec.id == "Pendulum-v1" + + +@pytest.mark.parametrize( + "dataset_name, env_name", + [ + ("door-cloned-v1", "AdroitHandDoor-v1"), + ("relocate-expert-v1", "AdroitHandRelocate-v1"), + ], +) +def test_get_minari(dataset_name: str, env_name: str) -> None: + _, env = get_minari(dataset_name) + assert env.unwrapped.spec.id == env_name