From 7be4f673129bca6c9c7a1a84f750a6a8aa4b52b4 Mon Sep 17 00:00:00 2001
From: takuseno
Date: Sat, 23 Nov 2024 21:21:09 +0900
Subject: [PATCH] Upgrade minari dependency

---
 d3rlpy/cli.py                             |  2 +-
 reproductions/finetuning/awac_finetune.py |  2 +-
 reproductions/finetuning/iql_finetune.py  |  2 +-
 tests/test_datasets.py                    | 69 +++++++++++------------
 4 files changed, 36 insertions(+), 39 deletions(-)

diff --git a/d3rlpy/cli.py b/d3rlpy/cli.py
index 67aa0f3f..2b35e18b 100644
--- a/d3rlpy/cli.py
+++ b/d3rlpy/cli.py
@@ -389,7 +389,7 @@ def print_available_options() -> None:
         _install_module(["gym"], upgrade=True)
         _uninstall_module(["pybullet"])
     elif name == "minari":
-        _install_module(["minari==0.4.2", "gymnasium_robotics"], upgrade=True)
+        _install_module(["minari[all]>=0.5.1"], upgrade=True)
     elif name == "dm_control":
         _install_module(["shimmy[dm-control]==1.3.0"], upgrade=True)
     elif name == "list":
diff --git a/reproductions/finetuning/awac_finetune.py b/reproductions/finetuning/awac_finetune.py
index 970cd361..074125b5 100644
--- a/reproductions/finetuning/awac_finetune.py
+++ b/reproductions/finetuning/awac_finetune.py
@@ -6,7 +6,7 @@
 
 def main() -> None:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset", type=str, default="antmaze-umaze-v0")
+    parser.add_argument("--dataset", type=str, default="D4RL/antmaze/umaze-v1")
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--gpu", type=int)
     parser.add_argument("--compile", action="store_true")
diff --git a/reproductions/finetuning/iql_finetune.py b/reproductions/finetuning/iql_finetune.py
index 1dc49da8..1a44fb29 100644
--- a/reproductions/finetuning/iql_finetune.py
+++ b/reproductions/finetuning/iql_finetune.py
@@ -7,7 +7,7 @@
 
 def main() -> None:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset", type=str, default="antmaze-umaze-v0")
+    parser.add_argument("--dataset", type=str, default="D4RL/antmaze/umaze-v1")
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--gpu", type=int)
     parser.add_argument("--compile", action="store_true")
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index bd376395..83cd4186 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1,6 +1,6 @@
 import pytest
 
-from d3rlpy.datasets import get_cartpole, get_dataset, get_pendulum
+from d3rlpy.datasets import get_cartpole, get_dataset, get_minari, get_pendulum
 
 
 @pytest.mark.parametrize("dataset_type", ["replay", "random"])
@@ -25,38 +25,35 @@ def test_get_dataset(env_name: str) -> None:
     assert env.unwrapped.spec.id == "Pendulum-v1"
 
 
-# @pytest.mark.parametrize(
-#     "dataset_name, env_name",
-#     [
-#         ("door-cloned-v2", "AdroitHandDoor-v1"),
-#         ("relocate-expert-v2", "AdroitHandRelocate-v1"),
-#         ("kitchen-complete-v1", "FrankaKitchen-v1"),
-#     ],
-# )
-# @pytest.mark.parametrize("tuple_observation", [False, True])
-# def test_get_minari(
-#     dataset_name: str, env_name: str, tuple_observation: bool
-# ) -> None:
-#     dataset, env = get_minari(
-#         dataset_name, tuple_observation=tuple_observation
-#     )
-#     assert env.unwrapped.spec.id == env_name  # type: ignore
-#
-#     if tuple_observation:
-#         # check shape
-#         ep = dataset.episodes[0]
-#         ref_shape0 = ep.observations[0].shape[1:]
-#         ref_shape1 = ep.observations[1].shape[1:]
-#         obs, _ = env.reset()
-#         assert obs[0].shape == ref_shape0
-#         assert obs[1].shape == ref_shape1
-#         obs, _, _, _, _ = env.step(env.action_space.sample())
-#         assert obs[0].shape == ref_shape0
-#         assert obs[1].shape == ref_shape1
-#     else:
-#         # check shape
-#         ref_shape = dataset.episodes[0].observations.shape[1:]  # type: ignore
-#         obs, _ = env.reset()
-#         assert obs.shape == ref_shape
-#         obs, _, _, _, _ = env.step(env.action_space.sample())
-#         assert obs.shape == ref_shape
+@pytest.mark.parametrize(
+    "dataset_name, env_name",
+    [
+        ("D4RL/door/cloned-v2", "AdroitHandDoor-v1"),
+        ("D4RL/kitchen/complete-v2", "FrankaKitchen-v1"),
+    ],
+)
+@pytest.mark.parametrize("tuple_observation", [False, True])
+def test_get_minari(
+    dataset_name: str, env_name: str, tuple_observation: bool
+) -> None:
+    dataset, env = get_minari(dataset_name, tuple_observation=tuple_observation)
+    assert env.unwrapped.spec.id == env_name  # type: ignore
+
+    if tuple_observation:
+        # check shape
+        ep = dataset.episodes[0]
+        ref_shape0 = ep.observations[0].shape[1:]
+        ref_shape1 = ep.observations[1].shape[1:]
+        obs, _ = env.reset()
+        assert obs[0].shape == ref_shape0
+        assert obs[1].shape == ref_shape1
+        obs, _, _, _, _ = env.step(env.action_space.sample())
+        assert obs[0].shape == ref_shape0
+        assert obs[1].shape == ref_shape1
+    else:
+        # check shape
+        ref_shape = dataset.episodes[0].observations.shape[1:]  # type: ignore
+        obs, _ = env.reset()
+        assert obs.shape == ref_shape
+        obs, _, _, _, _ = env.step(env.action_space.sample())
+        assert obs.shape == ref_shape