Fix SparseRewardTransitionPicker
takuseno committed May 11, 2024
1 parent 1625361 commit 0097fda
Showing 3 changed files with 13 additions and 10 deletions.
d3rlpy/dataset/transition_pickers.py (11 changes: 7 additions & 4 deletions)
@@ -75,22 +75,25 @@ class SparseRewardTransitionPicker(TransitionPickerProtocol):
     This class extends BasicTransitionPicker to handle special returns_to_go
     calculation mainly used in AntMaze environments.
     For the failure trajectories, this class sets the constant return value to
     avoid inconsistent horizon due to time out.
     Args:
-        horizon_length (int): Length to repeat rewards_to_go.
+        failure_return (int): Return value for failure trajectories.
         step_reward (float): Immediate step reward value in sparse reward
             setting.
     """
 
-    def __init__(self, horizon_length: int, step_reward: float = 0.0):
-        self._horizon_length = horizon_length
+    def __init__(self, failure_return: float, step_reward: float = 0.0):
+        self._failure_return = failure_return
         self._step_reward = step_reward
         self._transition_picker = BasicTransitionPicker()
 
     def __call__(self, episode: EpisodeBase, index: int) -> Transition:
         transition = self._transition_picker(episode, index)
         if np.all(transition.rewards_to_go == self._step_reward):
            extended_rewards_to_go: Float32NDArray = np.array(
-                [[self._step_reward]] * self._horizon_length,
+                [[self._failure_return]],
                 dtype=np.float32,
             )
             transition = dataclasses.replace(
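
For a concrete picture of the change above, here is a minimal usage sketch. It is not part of the commit and assumes the d3rlpy v2 dataset API (Episode constructed from observations, actions, rewards, and terminated; the picker called with an episode and an index); the failure_return of -100.0 is an arbitrary illustrative value.

import numpy as np
import d3rlpy

# A "failure" trajectory: every reward equals the sparse step reward (0.0),
# i.e. the goal was never reached before the episode timed out.
episode = d3rlpy.dataset.Episode(
    observations=np.random.random((100, 4)).astype(np.float32),
    actions=np.random.random((100, 2)).astype(np.float32),
    rewards=np.zeros((100, 1), dtype=np.float32),
    terminated=False,
)

picker = d3rlpy.dataset.SparseRewardTransitionPicker(
    failure_return=-100.0,
    step_reward=0.0,
)

transition = picker(episode, 0)
# After this fix, rewards_to_go for a failure trajectory collapses to a single
# constant entry of shape (1, 1) instead of step_reward repeated over a fixed
# horizon_length, so the horizon no longer depends on where the timeout fell.
print(transition.rewards_to_go)  # expected: [[-100.]]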
reproductions/finetuning/cal_ql_finetune.py (2 changes: 1 addition & 1 deletion)
@@ -14,7 +14,7 @@ def main() -> None:
 
     # sparse reward setup requires special treatment for failure trajectories
     transition_picker = d3rlpy.dataset.SparseRewardTransitionPicker(
-        horizon_length=100,
+        failure_return=-49.5,  # ((-5 / 0.01) + 5) / 10
         step_reward=0,
     )
 
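As a side note on the value used in the Cal-QL reproduction, the inline comment on failure_return evaluates to -49.5; spelling out the arithmetic (the commit does not explain what the individual constants represent):

# ((-5 / 0.01) + 5) / 10, evaluated step by step:
#   -5 / 0.01   = -500.0
#   -500.0 + 5  = -495.0
#   -495.0 / 10 = -49.5
assert ((-5 / 0.01) + 5) / 10 == -49.5
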
tests/dataset/test_transition_pickers.py (10 changes: 5 additions & 5 deletions)
@@ -193,14 +193,14 @@ def test_multi_step_transition_picker(
 @pytest.mark.parametrize("observation_shape", [(4,), ((4,), (8,))])
 @pytest.mark.parametrize("action_size", [2])
 @pytest.mark.parametrize("length", [100])
-@pytest.mark.parametrize("horizon_length", [1000])
+@pytest.mark.parametrize("failure_return", [-100])
 @pytest.mark.parametrize("step_reward", [0.0])
 @pytest.mark.parametrize("success", [False, True])
 def test_sparse_reward_transition_picker(
     observation_shape: Shape,
     action_size: int,
     length: int,
-    horizon_length: int,
+    failure_return: int,
     step_reward: float,
     success: bool,
 ) -> None:
@@ -214,7 +214,7 @@ def test_sparse_reward_transition_picker(
     )
 
     picker = SparseRewardTransitionPicker(
-        horizon_length=horizon_length,
+        failure_return=failure_return,
         step_reward=step_reward,
     )
 
@@ -223,5 +223,5 @@
     if success:
         assert np.all(transition.rewards_to_go == episode.rewards)
     else:
-        assert transition.rewards_to_go.shape == (horizon_length, 1)
-        assert np.all(transition.rewards_to_go == 0)
+        assert transition.rewards_to_go.shape == (1, 1)
+        assert transition.rewards_to_go[0][0] == failure_return
