diff --git a/d3rlpy/dataset/transition_pickers.py b/d3rlpy/dataset/transition_pickers.py
index 146a413e..c40533d2 100644
--- a/d3rlpy/dataset/transition_pickers.py
+++ b/d3rlpy/dataset/transition_pickers.py
@@ -1,6 +1,9 @@
+import dataclasses
+
 import numpy as np
 from typing_extensions import Protocol
 
+from ..types import Float32NDArray
 from .components import EpisodeBase, Transition
 from .utils import (
     create_zero_observation,
@@ -11,6 +14,7 @@
 __all__ = [
     "TransitionPickerProtocol",
     "BasicTransitionPicker",
+    "SparseRewardTransitionPicker",
     "FrameStackTransitionPicker",
     "MultiStepTransitionPicker",
 ]
@@ -65,6 +69,39 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition:
         )
 
 
+class SparseRewardTransitionPicker(TransitionPickerProtocol):
+    r"""Sparse reward transition picker.
+
+    This class extends BasicTransitionPicker to handle the special
+    returns-to-go calculation used mainly in AntMaze environments.
+
+    Args:
+        horizon_length (int): Length of the repeated rewards_to_go array.
+        step_reward (float): Immediate step reward value in the sparse
+            reward setting.
+    """
+
+    def __init__(self, horizon_length: int, step_reward: float = 0.0):
+        self._horizon_length = horizon_length
+        self._step_reward = step_reward
+        self._transition_picker = BasicTransitionPicker()
+
+    def __call__(self, episode: EpisodeBase, index: int) -> Transition:
+        transition = self._transition_picker(episode, index)
+        # If the trajectory never reached the goal, every reward-to-go equals
+        # the step reward. In that case, repeat it for horizon_length steps.
+        if np.all(transition.rewards_to_go == self._step_reward):
+            extended_rewards_to_go: Float32NDArray = np.array(
+                [[self._step_reward]] * self._horizon_length,
+                dtype=np.float32,
+            )
+            transition = dataclasses.replace(
+                transition,
+                rewards_to_go=extended_rewards_to_go,
+            )
+        return transition
+
+
 class FrameStackTransitionPicker(TransitionPickerProtocol):
     r"""Frame-stacking transition picker.
 
diff --git a/reproductions/finetuning/cal_ql_finetune.py b/reproductions/finetuning/cal_ql_finetune.py
index 3e5beb90..7e9f0077 100644
--- a/reproductions/finetuning/cal_ql_finetune.py
+++ b/reproductions/finetuning/cal_ql_finetune.py
@@ -12,7 +12,13 @@ def main() -> None:
     parser.add_argument("--gpu", type=int)
     args = parser.parse_args()
 
-    dataset, env = d3rlpy.datasets.get_d4rl(args.dataset)
+    dataset, env = d3rlpy.datasets.get_d4rl(
+        args.dataset,
+        transition_picker=d3rlpy.dataset.SparseRewardTransitionPicker(
+            horizon_length=100,
+            step_reward=0,
+        ),
+    )
 
     # fix seed
     d3rlpy.seed(args.seed)
diff --git a/tests/dataset/test_transition_pickers.py b/tests/dataset/test_transition_pickers.py
index c7c88bc9..7944c9ea 100644
--- a/tests/dataset/test_transition_pickers.py
+++ b/tests/dataset/test_transition_pickers.py
@@ -1,3 +1,5 @@
+import dataclasses
+
 import numpy as np
 import pytest
 
@@ -5,6 +7,7 @@
     BasicTransitionPicker,
     FrameStackTransitionPicker,
     MultiStepTransitionPicker,
+    SparseRewardTransitionPicker,
 )
 from d3rlpy.types import Shape
 
@@ -185,3 +188,43 @@
     assert np.allclose(transition.rewards_to_go, episode.rewards[-n_steps:])
     assert transition.interval == n_steps
     assert transition.terminal == 1.0
+
+
+@pytest.mark.parametrize("observation_shape", [(4,), ((4,), (8,))])
+@pytest.mark.parametrize("action_size", [2])
+@pytest.mark.parametrize("length", [100])
+@pytest.mark.parametrize("horizon_length", [1000])
+@pytest.mark.parametrize("step_reward", [0.0])
+@pytest.mark.parametrize("success", [False, True])
+def test_sparse_reward_transition_picker(
+    observation_shape: Shape,
+    action_size: int,
+    length: int,
+    horizon_length: int,
+    step_reward: float,
+    success: bool,
+) -> None:
+    episode = create_episode(
+        observation_shape, action_size, length, terminated=True
+    )
+    # Zero out rewards to simulate an episode that never reaches the goal.
+    if not success:
+        episode = dataclasses.replace(
+            episode,
+            rewards=np.zeros_like(episode.rewards),
+        )
+
+    picker = SparseRewardTransitionPicker(
+        horizon_length=horizon_length,
+        step_reward=step_reward,
+    )
+
+    transition = picker(episode, 0)
+
+    if success:
+        # Successful episodes keep the original returns-to-go.
+        assert np.all(transition.rewards_to_go == episode.rewards)
+    else:
+        # Unsuccessful episodes get step_reward repeated over the horizon.
+        assert transition.rewards_to_go.shape == (horizon_length, 1)
+        assert np.all(transition.rewards_to_go == 0)
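For reviewers, a minimal usage sketch of the new picker outside the Cal-QL script. This is not part of the patch; the dataset name is only illustrative, and any sparse-reward D4RL AntMaze task should work the same way:

    import d3rlpy

    # Replace the returns-to-go of unsuccessful transitions with the step
    # reward repeated over a fixed 100-step horizon, matching the Cal-QL
    # reproduction script above.
    picker = d3rlpy.dataset.SparseRewardTransitionPicker(
        horizon_length=100,
        step_reward=0.0,
    )

    # get_d4rl accepts the picker through its transition_picker argument,
    # as in the cal_ql_finetune.py change above.
    dataset, env = d3rlpy.datasets.get_d4rl(
        "antmaze-medium-diverse-v2",  # illustrative dataset name
        transition_picker=picker,
    )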