diff --git a/d3rlpy/dataset/mini_batch.py b/d3rlpy/dataset/mini_batch.py index c1ced741..f8d2654f 100644 --- a/d3rlpy/dataset/mini_batch.py +++ b/d3rlpy/dataset/mini_batch.py @@ -81,7 +81,7 @@ def from_transitions( [-1, 1], ) intervals = np.reshape( - np.array([transition.terminal for transition in transitions]), + np.array([transition.interval for transition in transitions]), [-1, 1], ) return TransitionMiniBatch( diff --git a/tests/dataset/test_mini_batch.py b/tests/dataset/test_mini_batch.py index bd7f66f2..2425c18c 100644 --- a/tests/dataset/test_mini_batch.py +++ b/tests/dataset/test_mini_batch.py @@ -15,25 +15,50 @@ def test_transition_mini_batch( ) -> None: transitions = [] for _ in range(batch_size): - transition = create_transition(observation_shape, action_size) + transition = create_transition( + observation_shape, + action_size, + terminated=bool(np.random.randint(2)), + ) transitions.append(transition) batch = TransitionMiniBatch.from_transitions(transitions) + ref_actions = np.array([t.action for t in transitions]) + ref_rewards = np.array([t.reward for t in transitions]) + ref_terminals = np.array([[t.terminal] for t in transitions]) + ref_intervals = np.array([[t.interval] for t in transitions]) + if isinstance(observation_shape[0], tuple): for i, shape in enumerate(observation_shape): + ref_observations = np.array([t.observation[i] for t in transitions]) + ref_next_observations = np.array( + [t.next_observation[i] for t in transitions] + ) assert isinstance(shape, tuple) assert batch.observations[i].shape == (batch_size, *shape) assert batch.next_observations[i].shape == (batch_size, *shape) + assert np.all(batch.observations[i] == ref_observations) + assert np.all(batch.next_observations[i] == ref_next_observations) else: + ref_observations = np.array([t.observation for t in transitions]) + ref_next_observations = np.array( + [t.next_observation for t in transitions] + ) assert isinstance(batch.observations, np.ndarray) assert isinstance(batch.next_observations, np.ndarray) assert batch.observations.shape == (batch_size, *observation_shape) assert batch.next_observations.shape == (batch_size, *observation_shape) + assert np.all(batch.observations == ref_observations) + assert np.all(batch.next_observations == ref_next_observations) assert batch.actions.shape == (batch_size, action_size) assert batch.rewards.shape == (batch_size, 1) assert batch.terminals.shape == (batch_size, 1) assert batch.intervals.shape == (batch_size, 1) + assert np.all(batch.actions == ref_actions) + assert np.all(batch.rewards == ref_rewards) + assert np.all(batch.terminals == ref_terminals) + assert np.all(batch.intervals == ref_intervals) @pytest.mark.parametrize("observation_shape", [(4,), ((4,), (8,))]) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 39be9a01..3e55c03a 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -112,19 +112,25 @@ def create_transition( observation: Observation next_observation: Observation if isinstance(observation_shape[0], (list, tuple)): - observation = [np.random.random(shape) for shape in observation_shape] + observation = [ + np.random.random(shape).astype(np.float32) + for shape in observation_shape + ] next_observation = [ - np.random.random(shape) for shape in observation_shape + np.random.random(shape).astype(np.float32) + for shape in observation_shape ] else: - observation = np.random.random(observation_shape) - next_observation = np.random.random(observation_shape) + observation = np.random.random(observation_shape).astype(np.float32) + next_observation = np.random.random(observation_shape).astype( + np.float32 + ) action: NDArray if discrete_action: action = np.random.randint(action_size, size=(1,)) else: - action = np.random.random(action_size) + action = np.random.random(action_size).astype(np.float32) return Transition( observation=observation,