diff --git a/d3rlpy/datasets.py b/d3rlpy/datasets.py
index 96ee07da..f180d7bd 100644
--- a/d3rlpy/datasets.py
+++ b/d3rlpy/datasets.py
@@ -470,28 +470,43 @@ def get_minari(
     try:
         import minari

-        _dataset: minari.MinariDataset = minari.load_dataset(
-            env_name, download=True
-        )
+        _dataset = minari.load_dataset(env_name, download=True)

-        data: Dict[str, List[NDArray]] = {
-            "observations": [],
-            "actions": [],
-            "rewards": [],
-            "terminations": [],
-            "truncations": [],
-        }
+        observations = []
+        actions = []
+        rewards = []
+        terminals = []
+        timeouts = []

         for ep in _dataset:
-            for k, v in data.items():
-                v.append(getattr(ep, k))
+            if isinstance(ep.observations, dict):
+                if (
+                    "desired_goal" in ep.observations
+                    and "observation" in ep.observations
+                ):
+                    _observations = np.concatenate(
+                        [
+                            ep.observations["observation"],
+                            ep.observations["desired_goal"],
+                        ],
+                        axis=-1,
+                    )
+                else:
+                    raise ValueError("Unsupported observation format.")
+            else:
+                _observations = ep.observations
+            observations.append(_observations)
+            actions.append(ep.actions)
+            rewards.append(ep.rewards)
+            terminals.append(ep.terminations)
+            timeouts.append(ep.truncations)

         dataset = MDPDataset(
-            observations=np.concatenate(data["observations"]),
-            actions=np.concatenate(data["actions"]),
-            rewards=np.concatenate(data["rewards"]),
-            terminals=np.concatenate(data["terminations"]),
-            timeouts=np.concatenate(data["truncations"]),
+            observations=np.concatenate(observations),
+            actions=np.concatenate(actions),
+            rewards=np.concatenate(rewards),
+            terminals=np.concatenate(terminals),
+            timeouts=np.concatenate(timeouts),
             transition_picker=transition_picker,
             trajectory_slicer=trajectory_slicer,
         )
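
For reference, the new branch handles goal-conditioned Minari episodes whose `ep.observations` is a dict by appending the `desired_goal` array to the `observation` array along the last axis. A minimal standalone sketch of that flattening step, using made-up shapes purely for illustration (not part of the patch):

```python
import numpy as np

# Illustrative shapes: a 3-step episode with a 4-D observation and a 2-D goal.
ep_observations = {
    "observation": np.zeros((3, 4)),
    "desired_goal": np.ones((3, 2)),
}

# Mirrors the dict branch in get_minari: per-step observations are flattened
# by concatenating the desired goal onto the observation vector.
flat = np.concatenate(
    [ep_observations["observation"], ep_observations["desired_goal"]],
    axis=-1,
)
assert flat.shape == (3, 6)
```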