diff --git a/d3rlpy/cli.py b/d3rlpy/cli.py index 332e04af..1a7fe308 100644 --- a/d3rlpy/cli.py +++ b/d3rlpy/cli.py @@ -359,9 +359,25 @@ def _uninstall_module(name: List[str], check: bool = True) -> None: subprocess.run(["pip3", "uninstall", "-y", *name], check=check) +INSTALL_OPTIONS = { + "atari": "Atari 2600 environments for Gym.", + "d4rl_atari": "Datasets for Atari 2600. https://github.com/takuseno/d4rl-atari", + "d4rl": "D4RL.", + "minari": "Minari.", + "dm_control": "DeepMind Control environments via Shimmy.", +} + + @cli.command(short_help="Install additional packages.") @click.argument("name") def install(name: str) -> None: + + def print_available_options() -> None: + print("List of available options.") + for name, description in INSTALL_OPTIONS.items(): + padding = " " * (15 - len(name)) + print(f"{name + padding}: {description}") + if name == "atari": _install_module(["gym[atari,accept-rom-license]"], upgrade=True) elif name == "d4rl_atari": @@ -375,5 +391,8 @@ def install(name: str) -> None: _install_module(["minari==0.4.2", "gymnasium_robotics"], upgrade=True) elif name == "dm_control": _install_module(["shimmy[dm-control]==1.3.0"], upgrade=True) + elif name == "list": + print_available_options() else: + print_available_options() raise ValueError(f"Unsupported command: {name}")