Add return-to-go to Transition
takuseno committed Nov 4, 2023
1 parent 1a70a8a commit 7094812
Showing 10 changed files with 92 additions and 7 deletions.
2 changes: 2 additions & 0 deletions d3rlpy/dataset/components.py
@@ -63,13 +63,15 @@ class Transition:
reward: Reward. This could be a multi-step discounted return.
next_observation: Observation at next timestep. This could be
observation at multi-step ahead.
return_to_go: Remaining return till the end of an episode.
terminal: Flag of environment termination.
interval: Timesteps between ``observation`` and ``next_observation``.
"""
observation: Observation # (...)
action: NDArray # (...)
reward: Float32NDArray # (1,)
next_observation: Observation # (...)
return_to_go: Float32NDArray # (1,)
terminal: float
interval: int
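
With the new field in place, constructing a Transition now takes a return_to_go array of shape (1,), mirroring reward. A minimal construction sketch with toy values (it assumes Transition is importable from d3rlpy.dataset, as the tests below do):

    import numpy as np
    from d3rlpy.dataset import Transition

    transition = Transition(
        observation=np.random.random(4).astype(np.float32),       # (...)
        action=np.random.random(2).astype(np.float32),            # (...)
        reward=np.random.random(1).astype(np.float32),            # (1,)
        next_observation=np.random.random(4).astype(np.float32),  # (...)
        return_to_go=np.random.random(1).astype(np.float32),      # (1,) remaining discounted return
        terminal=0.0,
        interval=1,
    )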

8 changes: 8 additions & 0 deletions d3rlpy/dataset/mini_batch.py
@@ -25,6 +25,7 @@ class TransitionMiniBatch:
actions: Batched actions.
rewards: Batched rewards.
next_observations: Batched next observations.
returns_to_go: Batched returns-to-go.
terminals: Batched environment terminal flags.
intervals: Batched timesteps between observations and next
observations.
@@ -35,6 +36,7 @@ class TransitionMiniBatch:
next_observations: Union[
Float32NDArray, Sequence[Float32NDArray]
] # (B, ...)
returns_to_go: Float32NDArray # (B, 1)
terminals: Float32NDArray # (B, 1)
intervals: Float32NDArray # (B, 1)

@@ -47,6 +49,8 @@ def __post_init__(self) -> None:
assert check_dtype(self.rewards, np.float32)
assert check_non_1d_array(self.next_observations)
assert check_dtype(self.next_observations, np.float32)
assert check_non_1d_array(self.returns_to_go)
assert check_dtype(self.returns_to_go, np.float32)
assert check_non_1d_array(self.terminals)
assert check_dtype(self.terminals, np.float32)
assert check_non_1d_array(self.intervals)
@@ -76,6 +80,9 @@ def from_transitions(
next_observations = stack_observations(
[transition.next_observation for transition in transitions]
)
returns_to_go = np.stack(
[transition.return_to_go for transition in transitions], axis=0
)
terminals = np.reshape(
np.array([transition.terminal for transition in transitions]),
[-1, 1],
@@ -89,6 +96,7 @@ def from_transitions(
actions=cast_recursively(actions, np.float32),
rewards=cast_recursively(rewards, np.float32),
next_observations=cast_recursively(next_observations, np.float32),
returns_to_go=cast_recursively(returns_to_go, np.float32),
terminals=cast_recursively(terminals, np.float32),
intervals=cast_recursively(intervals, np.float32),
)
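
Each transition contributes a (1,)-shaped return_to_go, and np.stack along axis 0 gives the batched (B, 1) array declared above. A short usage sketch (assuming `transitions` is a list of Transition objects such as the one sketched earlier):

    # transitions: list of Transition, each with return_to_go of shape (1,)
    batch = TransitionMiniBatch.from_transitions(transitions)
    assert batch.returns_to_go.shape == (len(transitions), 1)
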
35 changes: 33 additions & 2 deletions d3rlpy/dataset/transition_pickers.py
@@ -40,8 +40,16 @@ class BasicTransitionPicker(TransitionPickerProtocol):
r"""Standard transition picker.
This class implements a basic transition picking.
Args:
gamma (float): Discount factor to compute return-to-go.
"""

_gamma: float

def __init__(self, gamma: float = 0.99):
self._gamma = gamma

def __call__(self, episode: EpisodeBase, index: int) -> Transition:
_validate_index(episode, index)

@@ -53,11 +61,18 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition:
next_observation = retrieve_observation(
episode.observations, index + 1
)

# compute return-to-go
length = episode.size() - index
cum_gammas = np.expand_dims(self._gamma ** np.arange(length), axis=1)
return_to_go = np.sum(cum_gammas * episode.rewards[index:], axis=0)

return Transition(
observation=observation,
action=episode.actions[index],
reward=episode.rewards[index],
next_observation=next_observation,
return_to_go=return_to_go,
terminal=float(is_terminal),
interval=1,
)
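
The return-to-go at position index is the discounted sum of every remaining reward in the episode: R_index = r_index + gamma * r_{index+1} + ... + gamma^(length-1) * r_{T-1}. Since episode.rewards is (T, 1)-shaped (matching the (1,) reward annotation above), cum_gammas is expanded to (length, 1) so the product broadcasts and the sum over axis 0 yields a (1,) array. A standalone NumPy sketch of the same computation with toy rewards (not taken from the diff):

    import numpy as np

    gamma = 0.99
    rewards = np.array([[1.0], [0.0], [2.0], [3.0]], dtype=np.float32)  # (T, 1)
    index = 1

    # vectorized form, as in the picker above
    length = rewards.shape[0] - index
    cum_gammas = np.expand_dims(gamma ** np.arange(length), axis=1)  # (length, 1)
    return_to_go = np.sum(cum_gammas * rewards[index:], axis=0)      # (1,)

    # loop form: r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ...
    expected = sum(gamma ** k * rewards[index + k, 0] for k in range(length))
    assert np.isclose(return_to_go[0], expected)
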
@@ -85,13 +100,16 @@ class FrameStackTransitionPicker(TransitionPickerProtocol):
transition.observation.shape == (4, 84, 84)
Args:
n_frames: Number of frames to stack.
n_frames (int): Number of frames to stack.
gamma (float): Discount factor to compute return-to-go.
"""
_n_frames: int
_gamma: float

def __init__(self, n_frames: int):
def __init__(self, n_frames: int, gamma: float = 0.99):
assert n_frames > 0
self._n_frames = n_frames
self._gamma = gamma

def __call__(self, episode: EpisodeBase, index: int) -> Transition:
_validate_index(episode, index)
@@ -106,11 +124,18 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition:
next_observation = stack_recent_observations(
episode.observations, index + 1, self._n_frames
)

# compute return-to-go
length = episode.size() - index
cum_gammas = np.expand_dims(self._gamma ** np.arange(length), axis=1)
return_to_go = np.sum(cum_gammas * episode.rewards[index:], axis=0)

return Transition(
observation=observation,
action=episode.actions[index],
reward=episode.rewards[index],
next_observation=next_observation,
return_to_go=return_to_go,
terminal=float(is_terminal),
interval=1,
)
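
The frame-stacking picker computes the same return-to-go; only the observation handling differs. A brief usage sketch (assuming `episode` is an Episode with image observations, e.g. built with the test helpers):

    picker = FrameStackTransitionPicker(n_frames=4, gamma=0.99)
    transition = picker(episode, 10)
    # transition.observation stacks the 4 most recent frames, while
    # transition.return_to_go discounts episode.rewards[10:] to the episode end
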
@@ -156,6 +181,11 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition:
episode.observations, next_index
)

# compute return-to-go
length = episode.size() - index
cum_gammas = np.expand_dims(self._gamma ** np.arange(length), axis=1)
return_to_go = np.sum(cum_gammas * episode.rewards[index:], axis=0)

# compute multi-step return
interval = next_index - index
cum_gammas = np.expand_dims(self._gamma ** np.arange(interval), axis=1)
@@ -166,6 +196,7 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition:
action=episode.actions[index],
reward=ret,
next_observation=next_observation,
return_to_go=return_to_go,
terminal=float(is_terminal),
interval=interval,
)
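
Note the distinction between the two discounted sums computed here: return_to_go always discounts the rewards from index to the end of the episode, while the multi-step return stored in reward only covers the interval steps between observation and next_observation. The two coincide only when the interval extends to the episode end.
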
5 changes: 5 additions & 0 deletions d3rlpy/torch_utility.py
@@ -107,6 +107,7 @@ class TorchMiniBatch:
actions: torch.Tensor
rewards: torch.Tensor
next_observations: torch.Tensor
returns_to_go: torch.Tensor
terminals: torch.Tensor
intervals: torch.Tensor
device: str
@@ -128,6 +129,7 @@ def from_batch(
next_observations = convert_to_torch_recursively(
batch.next_observations, device
)
returns_to_go = convert_to_torch(batch.returns_to_go, device)
terminals = convert_to_torch(batch.terminals, device)
intervals = convert_to_torch(batch.intervals, device)

@@ -143,12 +145,15 @@ def from_batch(
actions = action_scaler.transform(actions)
if reward_scaler:
rewards = reward_scaler.transform(rewards)
# NOTE: some operations might be incompatible with returns
returns_to_go = reward_scaler.transform(returns_to_go)

return TorchMiniBatch(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
returns_to_go=returns_to_go,
terminals=terminals,
intervals=intervals,
device=device,
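The in-code NOTE deserves a gloss: passing an aggregated return through a per-step reward transformation is only exact when that transformation is purely multiplicative; shift-style scaling does not commute with the discounted sum. A standalone NumPy illustration of the caveat (toy numbers, not d3rlpy code):

    import numpy as np

    gamma = 0.99
    rewards = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    cum_gammas = gamma ** np.arange(len(rewards))
    rtg = np.sum(cum_gammas * rewards)

    # multiplicative scaling commutes with the discounted sum ...
    scale = 0.5
    assert np.isclose(np.sum(cum_gammas * (scale * rewards)), scale * rtg)

    # ... but an additive shift does not: shifting each step's reward before
    # summing is not the same as shifting the summed return once
    shift = 0.3
    assert not np.isclose(np.sum(cum_gammas * (rewards + shift)), rtg + shift)
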
1 change: 1 addition & 0 deletions tests/algos/qlearning/algo_test.py
@@ -295,6 +295,7 @@ def update_tester(
action=action,
reward=reward,
next_observation=next_observation,
return_to_go=reward,
terminal=terminal,
interval=1,
)
1 change: 1 addition & 0 deletions tests/dataset/test_components.py
@@ -24,6 +24,7 @@ def test_transition(observation_size: int, action_size: int) -> None:
action=np.random.random(action_size).astype(np.float32),
reward=np.random.random(1).astype(np.float32),
next_observation=np.random.random(observation_size).astype(np.float32),
return_to_go=np.random.random(1).astype(np.float32),
terminal=0.0,
interval=1,
)
3 changes: 3 additions & 0 deletions tests/dataset/test_mini_batch.py
@@ -26,6 +26,7 @@ def test_transition_mini_batch(

ref_actions = np.array([t.action for t in transitions])
ref_rewards = np.array([t.reward for t in transitions])
ref_returns_to_go = np.array([t.return_to_go for t in transitions])
ref_terminals = np.array([[t.terminal] for t in transitions])
ref_intervals = np.array([[t.interval] for t in transitions])

@@ -53,10 +54,12 @@ def test_transition_mini_batch(
assert np.all(batch.next_observations == ref_next_observations)
assert batch.actions.shape == (batch_size, action_size)
assert batch.rewards.shape == (batch_size, 1)
assert batch.returns_to_go.shape == (batch_size, 1)
assert batch.terminals.shape == (batch_size, 1)
assert batch.intervals.shape == (batch_size, 1)
assert np.all(batch.actions == ref_actions)
assert np.all(batch.rewards == ref_rewards)
assert np.all(batch.returns_to_go == ref_returns_to_go)
assert np.all(batch.terminals == ref_terminals)
assert np.all(batch.intervals == ref_intervals)

38 changes: 33 additions & 5 deletions tests/dataset/test_transition_pickers.py
@@ -3,25 +3,39 @@

from d3rlpy.dataset import (
BasicTransitionPicker,
Episode,
FrameStackTransitionPicker,
MultiStepTransitionPicker,
)
from d3rlpy.types import Shape
from d3rlpy.types import Float32NDArray, Shape

from ..testing_utils import create_episode


def _compute_returns_to_go(episode: Episode, gamma: float) -> Float32NDArray:
ref_returns_to_go = []
for i in range(episode.size()):
ret = episode.rewards[i].copy()
for j in range(i + 1, episode.size()):
ret += (gamma ** (j - i)) * episode.rewards[j]
ref_returns_to_go.append(ret)
return np.array(ref_returns_to_go, dtype=np.float32)


@pytest.mark.parametrize("observation_shape", [(4,), ((4,), (8,))])
@pytest.mark.parametrize("action_size", [2])
@pytest.mark.parametrize("length", [100])
@pytest.mark.parametrize("gamma", [0.99])
def test_basic_transition_picker(
observation_shape: Shape, action_size: int, length: int
observation_shape: Shape, action_size: int, length: int, gamma: float
) -> None:
episode = create_episode(
observation_shape, action_size, length, terminated=True
)

picker = BasicTransitionPicker()
ref_returns_to_go = _compute_returns_to_go(episode, gamma)

picker = BasicTransitionPicker(gamma=gamma)

# check transition
transition = picker(episode, 0)
@@ -40,6 +54,7 @@ def test_basic_transition_picker(
assert np.all(transition.next_observation == episode.observations[1])
assert np.all(transition.action == episode.actions[0])
assert np.all(transition.reward == episode.rewards[0])
assert np.allclose(transition.return_to_go, ref_returns_to_go[0])
assert transition.interval == 1
assert transition.terminal == 0

@@ -60,6 +75,7 @@ def test_basic_transition_picker(
assert np.all(transition.next_observation == dummy_observation)
assert np.all(transition.action == episode.actions[-1])
assert np.all(transition.reward == episode.rewards[-1])
assert np.allclose(transition.return_to_go, ref_returns_to_go[-1])
assert transition.interval == 1
assert transition.terminal == 1.0

@@ -68,14 +84,21 @@
@pytest.mark.parametrize("action_size", [2])
@pytest.mark.parametrize("length", [100])
@pytest.mark.parametrize("n_frames", [4])
@pytest.mark.parametrize("gamma", [0.99])
def test_frame_stack_transition_picker(
observation_shape: Shape, action_size: int, length: int, n_frames: int
observation_shape: Shape,
action_size: int,
length: int,
n_frames: int,
gamma: float,
) -> None:
episode = create_episode(
observation_shape, action_size, length, terminated=True
)

picker = FrameStackTransitionPicker(n_frames)
ref_returns_to_go = _compute_returns_to_go(episode, gamma)

picker = FrameStackTransitionPicker(n_frames, gamma=gamma)

n_channels = observation_shape[0]
assert isinstance(n_channels, int)
@@ -105,6 +128,7 @@ def test_frame_stack_transition_picker(
assert np.all(next_obs == 0.0)
assert np.all(transition.action == episode.actions[i])
assert np.all(transition.reward == episode.rewards[i])
assert np.allclose(transition.return_to_go, ref_returns_to_go[i])
assert transition.terminal == 0.0
assert transition.interval == 1

@@ -130,6 +154,8 @@ def test_multi_step_transition_picker(
observation_shape, action_size, length, terminated=True
)

ref_returns_to_go = _compute_returns_to_go(episode, gamma)

picker = MultiStepTransitionPicker(n_steps=n_steps, gamma=gamma)

# check transition
@@ -154,6 +180,7 @@
ref_reward = np.sum(gammas * np.reshape(episode.rewards[:n_steps], [-1]))
assert np.all(transition.action == episode.actions[0])
assert np.all(transition.reward == np.reshape(ref_reward, [1]))
assert np.allclose(transition.return_to_go, ref_returns_to_go[0])
assert transition.interval == n_steps
assert transition.terminal == 0

@@ -175,5 +202,6 @@
assert np.all(transition.action == episode.actions[-n_steps])
ref_reward = np.sum(gammas * np.reshape(episode.rewards[-n_steps:], [-1]))
assert np.all(transition.reward == np.reshape(ref_reward, [1]))
assert np.allclose(transition.return_to_go, ref_returns_to_go[-n_steps])
assert transition.interval == n_steps
assert transition.terminal == 1.0
5 changes: 5 additions & 0 deletions tests/test_torch_utility.py
@@ -214,6 +214,7 @@ def test_torch_mini_batch(
action=np.random.random(action_size),
reward=np.random.random((1,)).astype(np.float32),
next_observation=np.random.random(obs_shape),
return_to_go=np.random.random((1,)).astype(np.float32),
terminal=0.0,
interval=1,
)
@@ -267,8 +268,12 @@ def test_torch_mini_batch(

if use_reward_scaler:
assert np.all(torch_batch.rewards.numpy() == batch.rewards + 0.3)
assert np.all(
torch_batch.returns_to_go.numpy() == batch.returns_to_go + 0.3
)
else:
assert np.all(torch_batch.rewards.numpy() == batch.rewards)
assert np.all(torch_batch.returns_to_go.numpy() == batch.returns_to_go)

assert np.all(torch_batch.terminals.numpy() == batch.terminals)
assert np.all(torch_batch.intervals.numpy() == batch.intervals)
1 change: 1 addition & 0 deletions tests/testing_utils.py
@@ -137,6 +137,7 @@ def create_transition(
action=action,
reward=np.random.random(1).astype(np.float32),
next_observation=next_observation,
return_to_go=np.random.random(1).astype(np.float32),
terminal=1.0 if terminated else 0.0,
interval=1,
)
