Skip to content

Commit

Permalink
Add DMC support via Shimmy
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Jul 6, 2024
1 parent 29c82d5 commit 9d97d16
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 4 deletions.
11 changes: 11 additions & 0 deletions d3rlpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import random

import gymnasium
import numpy as np
import torch

Expand Down Expand Up @@ -64,3 +65,13 @@ def seed(n: int) -> None:

# run healthcheck
run_healthcheck()


# register Shimmy if available
try:
import shimmy

gymnasium.register_envs(shimmy)
logging.LOG.info("Register Shimmy environments.")
except ImportError:
pass
2 changes: 2 additions & 0 deletions d3rlpy/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,5 +373,7 @@ def install(name: str) -> None:
_uninstall_module(["pybullet"])
elif name == "minari":
_install_module(["minari==0.4.2", "gymnasium_robotics"], upgrade=True)
elif name == "dm_control":
_install_module(["shimmy[dm-control]"], upgrade=True)
else:
raise ValueError(f"Unsupported command: {name}")
2 changes: 1 addition & 1 deletion d3rlpy/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from gym.wrappers.time_limit import TimeLimit
from gymnasium.spaces import Box as GymnasiumBox
from gymnasium.spaces import Dict as GymnasiumDictSpace
from gymnasium.wrappers.time_limit import TimeLimit as GymnasiumTimeLimit
from gymnasium.wrappers import TimeLimit as GymnasiumTimeLimit

from .dataset import (
BasicTrajectorySlicer,
Expand Down
5 changes: 5 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,8 @@ ignore_missing_imports = True
ignore_missing_imports = True
follow_imports = skip
follow_imports_for_stubs = True

[mypy-shimmy.*]
ignore_missing_imports = True
follow_imports = skip
follow_imports_for_stubs = True
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
torch==2.0.1
tqdm>=4.66.3
tqdm>=4.66.1
h5py==2.10.0
gym==0.26.2
click==8.0.1
typing-extensions==3.7.4.3
structlog==20.2.0
colorama==0.4.4
gymnasium==0.29.0
gymnasium==1.0.0a1
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"structlog",
"colorama",
"dataclasses-json",
"gymnasium",
"gymnasium>=1.0.0a1",
],
packages=find_packages(exclude=["tests*"]),
python_requires=">=3.8.0",
Expand Down

0 comments on commit 9d97d16

Please sign in to comment.