Skip to content

Commit

Permalink
Upgrade minari dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Nov 23, 2024
1 parent 07cc474 commit 7be4f67
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 39 deletions.
2 changes: 1 addition & 1 deletion d3rlpy/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def print_available_options() -> None:
_install_module(["gym"], upgrade=True)
_uninstall_module(["pybullet"])
elif name == "minari":
_install_module(["minari==0.4.2", "gymnasium_robotics"], upgrade=True)
_install_module(["minari[all]>=0.5.1"], upgrade=True)
elif name == "dm_control":
_install_module(["shimmy[dm-control]==1.3.0"], upgrade=True)
elif name == "list":
Expand Down
2 changes: 1 addition & 1 deletion reproductions/finetuning/awac_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="antmaze-umaze-v0")
parser.add_argument("--dataset", type=str, default="D4RL/antmaze/umaze-v1")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--gpu", type=int)
parser.add_argument("--compile", action="store_true")
Expand Down
2 changes: 1 addition & 1 deletion reproductions/finetuning/iql_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="antmaze-umaze-v0")
parser.add_argument("--dataset", type=str, default="D4RL/antmaze/umaze-v1")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--gpu", type=int)
parser.add_argument("--compile", action="store_true")
Expand Down
69 changes: 33 additions & 36 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from d3rlpy.datasets import get_cartpole, get_dataset, get_pendulum
from d3rlpy.datasets import get_cartpole, get_dataset, get_minari, get_pendulum


@pytest.mark.parametrize("dataset_type", ["replay", "random"])
Expand All @@ -25,38 +25,35 @@ def test_get_dataset(env_name: str) -> None:
assert env.unwrapped.spec.id == "Pendulum-v1"


# @pytest.mark.parametrize(
# "dataset_name, env_name",
# [
# ("door-cloned-v2", "AdroitHandDoor-v1"),
# ("relocate-expert-v2", "AdroitHandRelocate-v1"),
# ("kitchen-complete-v1", "FrankaKitchen-v1"),
# ],
# )
# @pytest.mark.parametrize("tuple_observation", [False, True])
# def test_get_minari(
# dataset_name: str, env_name: str, tuple_observation: bool
# ) -> None:
# dataset, env = get_minari(
# dataset_name, tuple_observation=tuple_observation
# )
# assert env.unwrapped.spec.id == env_name # type: ignore
#
# if tuple_observation:
# # check shape
# ep = dataset.episodes[0]
# ref_shape0 = ep.observations[0].shape[1:]
# ref_shape1 = ep.observations[1].shape[1:]
# obs, _ = env.reset()
# assert obs[0].shape == ref_shape0
# assert obs[1].shape == ref_shape1
# obs, _, _, _, _ = env.step(env.action_space.sample())
# assert obs[0].shape == ref_shape0
# assert obs[1].shape == ref_shape1
# else:
# # check shape
# ref_shape = dataset.episodes[0].observations.shape[1:] # type: ignore
# obs, _ = env.reset()
# assert obs.shape == ref_shape
# obs, _, _, _, _ = env.step(env.action_space.sample())
# assert obs.shape == ref_shape
@pytest.mark.parametrize(
"dataset_name, env_name",
[
("D4RL/door/cloned-v2", "AdroitHandDoor-v1"),
("D4RL/kitchen/complete-v2", "FrankaKitchen-v1"),
],
)
@pytest.mark.parametrize("tuple_observation", [False, True])
def test_get_minari(
dataset_name: str, env_name: str, tuple_observation: bool
) -> None:
dataset, env = get_minari(dataset_name, tuple_observation=tuple_observation)
assert env.unwrapped.spec.id == env_name # type: ignore

if tuple_observation:
# check shape
ep = dataset.episodes[0]
ref_shape0 = ep.observations[0].shape[1:]
ref_shape1 = ep.observations[1].shape[1:]
obs, _ = env.reset()
assert obs[0].shape == ref_shape0
assert obs[1].shape == ref_shape1
obs, _, _, _, _ = env.step(env.action_space.sample())
assert obs[0].shape == ref_shape0
assert obs[1].shape == ref_shape1
else:
# check shape
ref_shape = dataset.episodes[0].observations.shape[1:] # type: ignore
obs, _ = env.reset()
assert obs.shape == ref_shape
obs, _, _, _, _ = env.step(env.action_space.sample())
assert obs.shape == ref_shape

0 comments on commit 7be4f67

Please sign in to comment.