Skip to content

Commit

Permalink
Add SparseRewardTransitionPicker for CalQL
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed May 7, 2024
1 parent 570b334 commit 645b49d
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 1 deletion.
35 changes: 35 additions & 0 deletions d3rlpy/dataset/transition_pickers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import dataclasses

import numpy as np
from typing_extensions import Protocol

from ..types import Float32NDArray
from .components import EpisodeBase, Transition
from .utils import (
create_zero_observation,
Expand All @@ -11,6 +14,7 @@
# Public names re-exported by this module; every concrete picker (and the
# protocol they implement) must be listed here to be part of the package API.
__all__ = [
    "TransitionPickerProtocol",
    "BasicTransitionPicker",
    "SparseRewardTransitionPicker",
    "FrameStackTransitionPicker",
    "MultiStepTransitionPicker",
]
Expand Down Expand Up @@ -65,6 +69,37 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition:
)


class SparseRewardTransitionPicker(TransitionPickerProtocol):
    r"""Sparse reward transition picker.

    Delegates picking to ``BasicTransitionPicker`` and post-processes the
    result: when every entry of the picked transition's rewards-to-go equals
    the immediate step reward (i.e. the sparse success reward was never
    observed from this index onward), the rewards-to-go are replaced with a
    constant ``(horizon_length, 1)`` column of ``step_reward``. This special
    returns-to-go handling is mainly used in AntMaze environments (Cal-QL).

    Args:
        horizon_length (int): Length to repeat rewards_to_go.
        step_reward (float): Immediate step reward value in sparse reward
            setting.
    """

    def __init__(self, horizon_length: int, step_reward: float = 0.0):
        self._horizon_length = horizon_length
        self._step_reward = step_reward
        self._base_picker = BasicTransitionPicker()

    def __call__(self, episode: EpisodeBase, index: int) -> Transition:
        transition = self._base_picker(episode, index)
        # Successful episodes (any reward differing from the step reward)
        # pass through untouched.
        if not np.all(transition.rewards_to_go == self._step_reward):
            return transition
        padded_rewards_to_go: Float32NDArray = np.full(
            (self._horizon_length, 1),
            self._step_reward,
            dtype=np.float32,
        )
        return dataclasses.replace(
            transition,
            rewards_to_go=padded_rewards_to_go,
        )


class FrameStackTransitionPicker(TransitionPickerProtocol):
r"""Frame-stacking transition picker.
Expand Down
8 changes: 7 additions & 1 deletion reproductions/finetuning/cal_ql_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ def main() -> None:
parser.add_argument("--gpu", type=int)
args = parser.parse_args()

dataset, env = d3rlpy.datasets.get_d4rl(args.dataset)
dataset, env = d3rlpy.datasets.get_d4rl(
args.dataset,
transition_picker=d3rlpy.dataset.SparseRewardTransitionPicker(
horizon_length=100,
step_reward=0,
),
)

# fix seed
d3rlpy.seed(args.seed)
Expand Down
40 changes: 40 additions & 0 deletions tests/dataset/test_transition_pickers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import dataclasses

import numpy as np
import pytest

from d3rlpy.dataset import (
BasicTransitionPicker,
FrameStackTransitionPicker,
MultiStepTransitionPicker,
SparseRewardTransitionPicker,
)
from d3rlpy.types import Shape

Expand Down Expand Up @@ -185,3 +188,40 @@ def test_multi_step_transition_picker(
assert np.allclose(transition.rewards_to_go, episode.rewards[-n_steps:])
assert transition.interval == n_steps
assert transition.terminal == 1.0


@pytest.mark.parametrize("observation_shape", [(4,), ((4,), (8,))])
@pytest.mark.parametrize("action_size", [2])
@pytest.mark.parametrize("length", [100])
@pytest.mark.parametrize("horizon_length", [1000])
@pytest.mark.parametrize("step_reward", [0.0])
@pytest.mark.parametrize("success", [False, True])
def test_sparse_reward_transition_picker(
    observation_shape: Shape,
    action_size: int,
    length: int,
    horizon_length: int,
    step_reward: float,
    success: bool,
) -> None:
    """Rewards-to-go are padded to the horizon only for failed episodes."""
    episode = create_episode(
        observation_shape, action_size, length, terminated=True
    )
    # Simulate a failed episode by zeroing every reward, so rewards-to-go
    # equal the sparse step reward at every step.
    if not success:
        zeroed_rewards = np.zeros_like(episode.rewards)
        episode = dataclasses.replace(episode, rewards=zeroed_rewards)

    picker = SparseRewardTransitionPicker(
        horizon_length=horizon_length,
        step_reward=step_reward,
    )
    transition = picker(episode, 0)

    if success:
        # Successful episode: rewards-to-go pass through unchanged.
        assert np.all(transition.rewards_to_go == episode.rewards)
    else:
        # Failed episode: padded to (horizon_length, 1) of the step reward.
        assert transition.rewards_to_go.shape == (horizon_length, 1)
        assert np.all(transition.rewards_to_go == 0)

0 comments on commit 645b49d

Please sign in to comment.