diff --git a/d3rlpy/dataset/components.py b/d3rlpy/dataset/components.py index bd424f0c..1386a6c1 100644 --- a/d3rlpy/dataset/components.py +++ b/d3rlpy/dataset/components.py @@ -63,6 +63,7 @@ class Transition: reward: Reward. This could be a multi-step discounted return. next_observation: Observation at next timestep. This could be observation at multi-step ahead. + return_to_go: Remaining return till the end of an episode. terminal: Flag of environment termination. interval: Timesteps between ``observation`` and ``next_observation``. """ @@ -70,6 +71,7 @@ class Transition: action: NDArray # (...) reward: Float32NDArray # (1,) next_observation: Observation # (...) + return_to_go: Float32NDArray # (1,) terminal: float interval: int diff --git a/d3rlpy/dataset/mini_batch.py b/d3rlpy/dataset/mini_batch.py index f8d2654f..965281b4 100644 --- a/d3rlpy/dataset/mini_batch.py +++ b/d3rlpy/dataset/mini_batch.py @@ -25,6 +25,7 @@ class TransitionMiniBatch: actions: Batched actions. rewards: Batched rewards. next_observations: Batched next observations. + returns_to_go: Batched returns-to-go. terminals: Batched environment terminal flags. intervals: Batched timesteps between observations and next observations. @@ -35,6 +36,7 @@ class TransitionMiniBatch: next_observations: Union[ Float32NDArray, Sequence[Float32NDArray] ] # (B, ...) + returns_to_go: Float32NDArray # (B, 1) terminals: Float32NDArray # (B, 1) intervals: Float32NDArray # (B, 1) @@ -47,6 +49,8 @@ def __post_init__(self) -> None: assert check_dtype(self.rewards, np.float32) assert check_non_1d_array(self.next_observations) assert check_dtype(self.next_observations, np.float32) + assert check_non_1d_array(self.returns_to_go) + assert check_dtype(self.returns_to_go, np.float32) assert check_non_1d_array(self.terminals) assert check_dtype(self.terminals, np.float32) assert check_non_1d_array(self.intervals) @@ -76,6 +80,9 @@ def from_transitions( next_observations = stack_observations( [transition.next_observation for transition in transitions] ) + returns_to_go = np.stack( + [transition.return_to_go for transition in transitions], axis=0 + ) terminals = np.reshape( np.array([transition.terminal for transition in transitions]), [-1, 1], @@ -89,6 +96,7 @@ def from_transitions( actions=cast_recursively(actions, np.float32), rewards=cast_recursively(rewards, np.float32), next_observations=cast_recursively(next_observations, np.float32), + returns_to_go=cast_recursively(returns_to_go, np.float32), terminals=cast_recursively(terminals, np.float32), intervals=cast_recursively(intervals, np.float32), ) diff --git a/d3rlpy/dataset/transition_pickers.py b/d3rlpy/dataset/transition_pickers.py index d174952b..75bec9b0 100644 --- a/d3rlpy/dataset/transition_pickers.py +++ b/d3rlpy/dataset/transition_pickers.py @@ -40,8 +40,16 @@ class BasicTransitionPicker(TransitionPickerProtocol): r"""Standard transition picker. This class implements a basic transition picking. + + Args: + gamma (float): Discount factor to compute return-to-go. 
""" + _gamma: float + + def __init__(self, gamma: float = 0.99): + self._gamma = gamma + def __call__(self, episode: EpisodeBase, index: int) -> Transition: _validate_index(episode, index) @@ -53,11 +61,18 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition: next_observation = retrieve_observation( episode.observations, index + 1 ) + + # compute return-to-go + length = episode.size() - index + cum_gammas = np.expand_dims(self._gamma ** np.arange(length), axis=1) + return_to_go = np.sum(cum_gammas * episode.rewards[index:], axis=0) + return Transition( observation=observation, action=episode.actions[index], reward=episode.rewards[index], next_observation=next_observation, + return_to_go=return_to_go, terminal=float(is_terminal), interval=1, ) @@ -85,13 +100,16 @@ class FrameStackTransitionPicker(TransitionPickerProtocol): transition.observation.shape == (4, 84, 84) Args: - n_frames: Number of frames to stack. + n_frames (int): Number of frames to stack. + gamma (float): Discount factor to compute return-to-go. """ _n_frames: int + _gamma: float - def __init__(self, n_frames: int): + def __init__(self, n_frames: int, gamma: float = 0.99): assert n_frames > 0 self._n_frames = n_frames + self._gamma = gamma def __call__(self, episode: EpisodeBase, index: int) -> Transition: _validate_index(episode, index) @@ -106,11 +124,18 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition: next_observation = stack_recent_observations( episode.observations, index + 1, self._n_frames ) + + # compute return-to-go + length = episode.size() - index + cum_gammas = np.expand_dims(self._gamma ** np.arange(length), axis=1) + return_to_go = np.sum(cum_gammas * episode.rewards[index:], axis=0) + return Transition( observation=observation, action=episode.actions[index], reward=episode.rewards[index], next_observation=next_observation, + return_to_go=return_to_go, terminal=float(is_terminal), interval=1, ) @@ -156,6 +181,11 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition: episode.observations, next_index ) + # compute return-to-go + length = episode.size() - index + cum_gammas = np.expand_dims(self._gamma ** np.arange(length), axis=1) + return_to_go = np.sum(cum_gammas * episode.rewards[index:], axis=0) + # compute multi-step return interval = next_index - index cum_gammas = np.expand_dims(self._gamma ** np.arange(interval), axis=1) @@ -166,6 +196,7 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition: action=episode.actions[index], reward=ret, next_observation=next_observation, + return_to_go=return_to_go, terminal=float(is_terminal), interval=interval, ) diff --git a/d3rlpy/torch_utility.py b/d3rlpy/torch_utility.py index f9ebef5a..505fd4e1 100644 --- a/d3rlpy/torch_utility.py +++ b/d3rlpy/torch_utility.py @@ -107,6 +107,7 @@ class TorchMiniBatch: actions: torch.Tensor rewards: torch.Tensor next_observations: torch.Tensor + returns_to_go: torch.Tensor terminals: torch.Tensor intervals: torch.Tensor device: str @@ -128,6 +129,7 @@ def from_batch( next_observations = convert_to_torch_recursively( batch.next_observations, device ) + returns_to_go = convert_to_torch(batch.returns_to_go, device) terminals = convert_to_torch(batch.terminals, device) intervals = convert_to_torch(batch.intervals, device) @@ -143,12 +145,15 @@ def from_batch( actions = action_scaler.transform(actions) if reward_scaler: rewards = reward_scaler.transform(rewards) + # NOTE: some operations might be incompatible with returns + returns_to_go = 
reward_scaler.transform(returns_to_go) return TorchMiniBatch( observations=observations, actions=actions, rewards=rewards, next_observations=next_observations, + returns_to_go=returns_to_go, terminals=terminals, intervals=intervals, device=device, diff --git a/tests/algos/qlearning/algo_test.py b/tests/algos/qlearning/algo_test.py index f050881f..ae151e61 100644 --- a/tests/algos/qlearning/algo_test.py +++ b/tests/algos/qlearning/algo_test.py @@ -295,6 +295,7 @@ def update_tester( action=action, reward=reward, next_observation=next_observation, + return_to_go=reward, terminal=terminal, interval=1, ) diff --git a/tests/dataset/test_components.py b/tests/dataset/test_components.py index 71bb1db1..39df5c3d 100644 --- a/tests/dataset/test_components.py +++ b/tests/dataset/test_components.py @@ -24,6 +24,7 @@ def test_transition(observation_size: int, action_size: int) -> None: action=np.random.random(action_size).astype(np.float32), reward=np.random.random(1).astype(np.float32), next_observation=np.random.random(observation_size).astype(np.float32), + return_to_go=np.random.random(1).astype(np.float32), terminal=0.0, interval=1, ) diff --git a/tests/dataset/test_mini_batch.py b/tests/dataset/test_mini_batch.py index 2425c18c..5e9d68b0 100644 --- a/tests/dataset/test_mini_batch.py +++ b/tests/dataset/test_mini_batch.py @@ -26,6 +26,7 @@ def test_transition_mini_batch( ref_actions = np.array([t.action for t in transitions]) ref_rewards = np.array([t.reward for t in transitions]) + ref_returns_to_go = np.array([t.return_to_go for t in transitions]) ref_terminals = np.array([[t.terminal] for t in transitions]) ref_intervals = np.array([[t.interval] for t in transitions]) @@ -53,10 +54,12 @@ def test_transition_mini_batch( assert np.all(batch.next_observations == ref_next_observations) assert batch.actions.shape == (batch_size, action_size) assert batch.rewards.shape == (batch_size, 1) + assert batch.returns_to_go.shape == (batch_size, 1) assert batch.terminals.shape == (batch_size, 1) assert batch.intervals.shape == (batch_size, 1) assert np.all(batch.actions == ref_actions) assert np.all(batch.rewards == ref_rewards) + assert np.all(batch.returns_to_go == ref_returns_to_go) assert np.all(batch.terminals == ref_terminals) assert np.all(batch.intervals == ref_intervals) diff --git a/tests/dataset/test_transition_pickers.py b/tests/dataset/test_transition_pickers.py index 1d63f39b..d12af608 100644 --- a/tests/dataset/test_transition_pickers.py +++ b/tests/dataset/test_transition_pickers.py @@ -3,25 +3,39 @@ from d3rlpy.dataset import ( BasicTransitionPicker, + Episode, FrameStackTransitionPicker, MultiStepTransitionPicker, ) -from d3rlpy.types import Shape +from d3rlpy.types import Float32NDArray, Shape from ..testing_utils import create_episode +def _compute_returns_to_go(episode: Episode, gamma: float) -> Float32NDArray: + ref_returns_to_go = [] + for i in range(episode.size()): + ret = episode.rewards[i].copy() + for j in range(i + 1, episode.size()): + ret += (gamma ** (j - i)) * episode.rewards[j] + ref_returns_to_go.append(ret) + return np.array(ref_returns_to_go, dtype=np.float32) + + @pytest.mark.parametrize("observation_shape", [(4,), ((4,), (8,))]) @pytest.mark.parametrize("action_size", [2]) @pytest.mark.parametrize("length", [100]) +@pytest.mark.parametrize("gamma", [0.99]) def test_basic_transition_picker( - observation_shape: Shape, action_size: int, length: int + observation_shape: Shape, action_size: int, length: int, gamma: float ) -> None: episode = create_episode( 
observation_shape, action_size, length, terminated=True ) - picker = BasicTransitionPicker() + ref_returns_to_go = _compute_returns_to_go(episode, gamma) + + picker = BasicTransitionPicker(gamma=gamma) # check transition transition = picker(episode, 0) @@ -40,6 +54,7 @@ def test_basic_transition_picker( assert np.all(transition.next_observation == episode.observations[1]) assert np.all(transition.action == episode.actions[0]) assert np.all(transition.reward == episode.rewards[0]) + assert np.allclose(transition.return_to_go, ref_returns_to_go[0]) assert transition.interval == 1 assert transition.terminal == 0 @@ -60,6 +75,7 @@ def test_basic_transition_picker( assert np.all(transition.next_observation == dummy_observation) assert np.all(transition.action == episode.actions[-1]) assert np.all(transition.reward == episode.rewards[-1]) + assert np.allclose(transition.return_to_go, ref_returns_to_go[-1]) assert transition.interval == 1 assert transition.terminal == 1.0 @@ -68,14 +84,21 @@ def test_basic_transition_picker( @pytest.mark.parametrize("action_size", [2]) @pytest.mark.parametrize("length", [100]) @pytest.mark.parametrize("n_frames", [4]) +@pytest.mark.parametrize("gamma", [0.99]) def test_frame_stack_transition_picker( - observation_shape: Shape, action_size: int, length: int, n_frames: int + observation_shape: Shape, + action_size: int, + length: int, + n_frames: int, + gamma: float, ) -> None: episode = create_episode( observation_shape, action_size, length, terminated=True ) - picker = FrameStackTransitionPicker(n_frames) + ref_returns_to_go = _compute_returns_to_go(episode, gamma) + + picker = FrameStackTransitionPicker(n_frames, gamma=gamma) n_channels = observation_shape[0] assert isinstance(n_channels, int) @@ -105,6 +128,7 @@ def test_frame_stack_transition_picker( assert np.all(next_obs == 0.0) assert np.all(transition.action == episode.actions[i]) assert np.all(transition.reward == episode.rewards[i]) + assert np.allclose(transition.return_to_go, ref_returns_to_go[i]) assert transition.terminal == 0.0 assert transition.interval == 1 @@ -130,6 +154,8 @@ def test_multi_step_transition_picker( observation_shape, action_size, length, terminated=True ) + ref_returns_to_go = _compute_returns_to_go(episode, gamma) + picker = MultiStepTransitionPicker(n_steps=n_steps, gamma=gamma) # check transition @@ -154,6 +180,7 @@ def test_multi_step_transition_picker( ref_reward = np.sum(gammas * np.reshape(episode.rewards[:n_steps], [-1])) assert np.all(transition.action == episode.actions[0]) assert np.all(transition.reward == np.reshape(ref_reward, [1])) + assert np.allclose(transition.return_to_go, ref_returns_to_go[0]) assert transition.interval == n_steps assert transition.terminal == 0 @@ -175,5 +202,6 @@ def test_multi_step_transition_picker( assert np.all(transition.action == episode.actions[-n_steps]) ref_reward = np.sum(gammas * np.reshape(episode.rewards[-n_steps:], [-1])) assert np.all(transition.reward == np.reshape(ref_reward, [1])) + assert np.allclose(transition.return_to_go, ref_returns_to_go[-n_steps]) assert transition.interval == n_steps assert transition.terminal == 1.0 diff --git a/tests/test_torch_utility.py b/tests/test_torch_utility.py index bcde368c..d92e6ff1 100644 --- a/tests/test_torch_utility.py +++ b/tests/test_torch_utility.py @@ -214,6 +214,7 @@ def test_torch_mini_batch( action=np.random.random(action_size), reward=np.random.random((1,)).astype(np.float32), next_observation=np.random.random(obs_shape), + 
return_to_go=np.random.random((1,)).astype(np.float32), terminal=0.0, interval=1, ) @@ -267,8 +268,12 @@ def test_torch_mini_batch( if use_reward_scaler: assert np.all(torch_batch.rewards.numpy() == batch.rewards + 0.3) + assert np.all( + torch_batch.returns_to_go.numpy() == batch.returns_to_go + 0.3 + ) else: assert np.all(torch_batch.rewards.numpy() == batch.rewards) + assert np.all(torch_batch.returns_to_go.numpy() == batch.returns_to_go) assert np.all(torch_batch.terminals.numpy() == batch.terminals) assert np.all(torch_batch.intervals.numpy() == batch.intervals) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 3e55c03a..17d68b12 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -137,6 +137,7 @@ def create_transition( action=action, reward=np.random.random(1).astype(np.float32), next_observation=next_observation, + return_to_go=np.random.random(1).astype(np.float32), terminal=1.0 if terminated else 0.0, interval=1, )
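
For reference, below is a minimal standalone sketch of the return-to-go computation that the transition pickers in this patch perform. The helper name compute_return_to_go is illustrative only (it is not part of the patch); it simply mirrors the in-picker logic of applying gamma ** arange over the remaining rewards of the episode.

import numpy as np

def compute_return_to_go(rewards: np.ndarray, index: int, gamma: float = 0.99) -> np.ndarray:
    """Discounted sum of rewards from `index` to the end of the episode.

    Args:
        rewards: Episode rewards with shape (T, 1).
        index: Timestep at which the transition starts.
        gamma: Discount factor.

    Returns:
        Array of shape (1,), matching the shape of Transition.return_to_go.
    """
    length = rewards.shape[0] - index
    # (length, 1) column of discount factors: [1, gamma, gamma**2, ...]
    cum_gammas = np.expand_dims(gamma ** np.arange(length), axis=1)
    # Discounted sum over the remaining rewards of the episode.
    return np.sum(cum_gammas * rewards[index:], axis=0)

# Toy check: with rewards [1, 1, 1] and gamma = 0.5, the return-to-go at
# index 0 is 1 + 0.5 + 0.25 = 1.75.
rewards = np.ones((3, 1), dtype=np.float32)
assert np.isclose(compute_return_to_go(rewards, 0, gamma=0.5)[0], 1.75)

Note that, unlike the multi-step reward in MultiStepTransitionPicker (which discounts only over the n-step interval), return_to_go is always accumulated to the end of the episode, which is why the reference test helper _compute_returns_to_go iterates to episode.size().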