diff --git a/d3rlpy/algos/gato/base.py b/d3rlpy/algos/gato/base.py index 5b00980c..d25ec315 100644 --- a/d3rlpy/algos/gato/base.py +++ b/d3rlpy/algos/gato/base.py @@ -29,7 +29,7 @@ ) from ...metrics import evaluate_gato_with_environment from ...models import EmbeddingModuleFactory, TokenEmbeddingFactory -from ...models.torch import SeparatorTokenEmbedding, TokenEmbedding +from ...models.torch import SeparatorTokenEmbedding, TokenEmbedding, get_parameter from ...serializable_config import generate_dict_config_field from ...torch_utility import eval_api, train_api from ...types import GymEnv, NDArray, Observation @@ -236,7 +236,7 @@ def _append_action_embedding(self, embedding: torch.Tensor) -> None: def _append_separator_embedding(self) -> None: assert self._algo.impl - self._embeddings.append(self._algo.impl.separator_token_embedding.data) + self._embeddings.append(get_parameter(self._algo.impl.separator_token_embedding)) self._observation_positions.append(0) self._observation_masks.append(0) self._action_masks.append(0) @@ -389,6 +389,7 @@ def fit( replay_buffers=datasets, token_embeddings=self._impl.token_embeddings, separator_token_embedding=self._impl.separator_token_embedding, + prompt_probability=0.25, ) # save hyperparameters diff --git a/d3rlpy/algos/gato/dataset.py b/d3rlpy/algos/gato/dataset.py index 50b4437b..49ec0a64 100644 --- a/d3rlpy/algos/gato/dataset.py +++ b/d3rlpy/algos/gato/dataset.py @@ -1,6 +1,6 @@ import dataclasses from collections import defaultdict -from typing import DefaultDict, Dict, List, Sequence, Union +from typing import DefaultDict, Dict, List, Sequence, Union, Optional import numpy as np import torch @@ -200,6 +200,7 @@ def __call__( token_size: int, token_embeddings: Dict[str, TokenEmbedding], separator_token_embedding: SeparatorTokenEmbedding, + prompt_episode: Optional[GatoTokenEpisode] = None, ) -> GatoTrainingInputEmbedding: num_steps = token_size // episode.one_step_block_size end_step = end_step + 1 @@ -236,7 +237,7 @@ def __call__( separator_embeddings = separator_token_embedding(action_embedding) # concat observations and actions - # S = total_obs_num_tokens + action_num_tokens + # S = total_obs_num_tokens + action_num_tokens + 1 # (T, S, N) concat_embeddings = torch.cat( [ @@ -282,6 +283,23 @@ def __call__( # compute backward padding size pad_size = token_size - actual_num_steps * episode.one_step_block_size + if pad_size > 0 and prompt_episode: + prompt = self.__call__( + episode=prompt_episode, + end_step=prompt_episode.size() - 1, + token_size=pad_size, + token_embeddings=token_embeddings, + separator_token_embedding=separator_token_embedding, + ) + return GatoTrainingInputEmbedding( + embeddings=torch.cat([prompt.embeddings, flatten_embeddings], dim=0), + observation_positions=torch.cat([prompt.observation_positions, flatten_observation_positions], dim=0), + observation_masks=torch.cat([prompt.observation_masks, flatten_observation_masks], dim=0), + action_masks=torch.cat([prompt.action_masks, flatten_action_masks], dim=0), + action_tokens=torch.cat([prompt.action_tokens, flatten_action_tokens], dim=0), + masks=torch.cat([torch.zeros_like(prompt.masks), masks], dim=0), + ) + if pad_size == 0: return GatoTrainingInputEmbedding( embeddings=flatten_embeddings, @@ -362,16 +380,19 @@ class GatoReplayBuffer: _token_slicer: GatoTokenSlicer _token_embeddings: Dict[str, TokenEmbedding] _separator_token_embedding: SeparatorTokenEmbedding + _prompt_probability: float def __init__( self, replay_buffers: Sequence[ReplayBufferWithEmbeddingKeys], token_embeddings: Dict[str, TokenEmbedding], separator_token_embedding: SeparatorTokenEmbedding, + prompt_probability: float, ): self._token_slicer = GatoTokenSlicer() self._token_embeddings = token_embeddings self._separator_token_embedding = separator_token_embedding + self._prompt_probability = prompt_probability self._episodes = [] self._episodes_per_task = defaultdict(list) for replay_buffer in replay_buffers: @@ -394,12 +415,20 @@ def sample_embedding_sequence( ) -> GatoTrainingInputEmbedding: episode = self._episodes[int(np.random.randint(len(self._episodes)))] end_step = int(np.random.randint(episode.size())) + if np.random.random() < self._prompt_probability: + task_id = episode.task_id + num_episodes = len(self._episodes_per_task[task_id]) + prompt_index = np.random.randint(num_episodes) + prompt_episode = self._episodes_per_task[task_id][prompt_index] + else: + prompt_episode = None return self._token_slicer( episode=episode, end_step=end_step, token_size=length, token_embeddings=self._token_embeddings, separator_token_embedding=self._separator_token_embedding, + prompt_episode=prompt_episode, ) def sample_embedding_mini_batch( diff --git a/d3rlpy/algos/gato/torch/gato_impl.py b/d3rlpy/algos/gato/torch/gato_impl.py index 62d325d5..c54d6780 100644 --- a/d3rlpy/algos/gato/torch/gato_impl.py +++ b/d3rlpy/algos/gato/torch/gato_impl.py @@ -119,7 +119,8 @@ def compute_loss(self, batch: GatoEmbeddingMiniBatch) -> torch.Tensor: batch.action_tokens[:, 1:].reshape(-1).long(), reduction="none", ) - return (loss * batch.action_masks[:, 1:, :].reshape(-1)).mean() + masks = batch.masks[:, 1:, :] * batch.action_masks[:, 1:, :] + return (loss * masks.reshape(-1)).mean() @property def token_embeddings(self) -> Dict[str, TokenEmbedding]: diff --git a/d3rlpy/models/torch/embeddings.py b/d3rlpy/models/torch/embeddings.py index d74e4673..534d4689 100644 --- a/d3rlpy/models/torch/embeddings.py +++ b/d3rlpy/models/torch/embeddings.py @@ -4,7 +4,7 @@ from ...tokenizers import Tokenizer from ...types import Int32NDArray, NDArray -from .parameters import Parameter +from .parameters import Parameter, get_parameter __all__ = [ "TokenEmbedding", @@ -65,10 +65,6 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: assert x.ndim == 3 - assert x.shape[-1] == self._data.data.shape[0] - data = self._data.data.view(1, 1, -1) + assert x.shape[-1] == get_parameter(self._data).shape[0] + data = get_parameter(self._data).view(1, 1, -1) return torch.tile(data, [x.shape[0], 1, 1]) - - @property - def data(self) -> torch.Tensor: - return self._data.data