Skip to content

Commit

Permalink
Add prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Jan 13, 2024
1 parent 0095338 commit 0da955e
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 12 deletions.
5 changes: 3 additions & 2 deletions d3rlpy/algos/gato/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)
from ...metrics import evaluate_gato_with_environment
from ...models import EmbeddingModuleFactory, TokenEmbeddingFactory
from ...models.torch import SeparatorTokenEmbedding, TokenEmbedding
from ...models.torch import SeparatorTokenEmbedding, TokenEmbedding, get_parameter
from ...serializable_config import generate_dict_config_field
from ...torch_utility import eval_api, train_api
from ...types import GymEnv, NDArray, Observation
Expand Down Expand Up @@ -236,7 +236,7 @@ def _append_action_embedding(self, embedding: torch.Tensor) -> None:

def _append_separator_embedding(self) -> None:
assert self._algo.impl
self._embeddings.append(self._algo.impl.separator_token_embedding.data)
self._embeddings.append(get_parameter(self._algo.impl.separator_token_embedding))
self._observation_positions.append(0)
self._observation_masks.append(0)
self._action_masks.append(0)
Expand Down Expand Up @@ -389,6 +389,7 @@ def fit(
replay_buffers=datasets,
token_embeddings=self._impl.token_embeddings,
separator_token_embedding=self._impl.separator_token_embedding,
prompt_probability=0.25,
)

# save hyperparameters
Expand Down
33 changes: 31 additions & 2 deletions d3rlpy/algos/gato/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
from collections import defaultdict
from typing import DefaultDict, Dict, List, Sequence, Union
from typing import DefaultDict, Dict, List, Sequence, Union, Optional

import numpy as np
import torch
Expand Down Expand Up @@ -200,6 +200,7 @@ def __call__(
token_size: int,
token_embeddings: Dict[str, TokenEmbedding],
separator_token_embedding: SeparatorTokenEmbedding,
prompt_episode: Optional[GatoTokenEpisode] = None,
) -> GatoTrainingInputEmbedding:
num_steps = token_size // episode.one_step_block_size
end_step = end_step + 1
Expand Down Expand Up @@ -236,7 +237,7 @@ def __call__(
separator_embeddings = separator_token_embedding(action_embedding)

# concat observations and actions
# S = total_obs_num_tokens + action_num_tokens
# S = total_obs_num_tokens + action_num_tokens + 1
# (T, S, N)
concat_embeddings = torch.cat(
[
Expand Down Expand Up @@ -282,6 +283,23 @@ def __call__(
# compute backward padding size
pad_size = token_size - actual_num_steps * episode.one_step_block_size

if pad_size > 0 and prompt_episode:
prompt = self.__call__(
episode=prompt_episode,
end_step=prompt_episode.size() - 1,
token_size=pad_size,
token_embeddings=token_embeddings,
separator_token_embedding=separator_token_embedding,
)
return GatoTrainingInputEmbedding(
embeddings=torch.cat([prompt.embeddings, flatten_embeddings], dim=0),
observation_positions=torch.cat([prompt.observation_positions, flatten_observation_positions], dim=0),
observation_masks=torch.cat([prompt.observation_masks, flatten_observation_masks], dim=0),
action_masks=torch.cat([prompt.action_masks, flatten_action_masks], dim=0),
action_tokens=torch.cat([prompt.action_tokens, flatten_action_tokens], dim=0),
masks=torch.cat([torch.zeros_like(prompt.masks), masks], dim=0),
)

if pad_size == 0:
return GatoTrainingInputEmbedding(
embeddings=flatten_embeddings,
Expand Down Expand Up @@ -362,16 +380,19 @@ class GatoReplayBuffer:
_token_slicer: GatoTokenSlicer
_token_embeddings: Dict[str, TokenEmbedding]
_separator_token_embedding: SeparatorTokenEmbedding
_prompt_probability: float

def __init__(
self,
replay_buffers: Sequence[ReplayBufferWithEmbeddingKeys],
token_embeddings: Dict[str, TokenEmbedding],
separator_token_embedding: SeparatorTokenEmbedding,
prompt_probability: float,
):
self._token_slicer = GatoTokenSlicer()
self._token_embeddings = token_embeddings
self._separator_token_embedding = separator_token_embedding
self._prompt_probability = prompt_probability
self._episodes = []
self._episodes_per_task = defaultdict(list)
for replay_buffer in replay_buffers:
Expand All @@ -394,12 +415,20 @@ def sample_embedding_sequence(
) -> GatoTrainingInputEmbedding:
episode = self._episodes[int(np.random.randint(len(self._episodes)))]
end_step = int(np.random.randint(episode.size()))
if np.random.random() < self._prompt_probability:
task_id = episode.task_id
num_episodes = len(self._episodes_per_task[task_id])
prompt_index = np.random.randint(num_episodes)
prompt_episode = self._episodes_per_task[task_id][prompt_index]
else:
prompt_episode = None
return self._token_slicer(
episode=episode,
end_step=end_step,
token_size=length,
token_embeddings=self._token_embeddings,
separator_token_embedding=self._separator_token_embedding,
prompt_episode=prompt_episode,
)

def sample_embedding_mini_batch(
Expand Down
3 changes: 2 additions & 1 deletion d3rlpy/algos/gato/torch/gato_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ def compute_loss(self, batch: GatoEmbeddingMiniBatch) -> torch.Tensor:
batch.action_tokens[:, 1:].reshape(-1).long(),
reduction="none",
)
return (loss * batch.action_masks[:, 1:, :].reshape(-1)).mean()
masks = batch.masks[:, 1:, :] * batch.action_masks[:, 1:, :]
return (loss * masks.reshape(-1)).mean()

@property
def token_embeddings(self) -> Dict[str, TokenEmbedding]:
Expand Down
10 changes: 3 additions & 7 deletions d3rlpy/models/torch/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ...tokenizers import Tokenizer
from ...types import Int32NDArray, NDArray
from .parameters import Parameter
from .parameters import Parameter, get_parameter

__all__ = [
"TokenEmbedding",
Expand Down Expand Up @@ -65,10 +65,6 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor:

def forward(self, x: torch.Tensor) -> torch.Tensor:
assert x.ndim == 3
assert x.shape[-1] == self._data.data.shape[0]
data = self._data.data.view(1, 1, -1)
assert x.shape[-1] == get_parameter(self._data).shape[0]
data = get_parameter(self._data).view(1, 1, -1)
return torch.tile(data, [x.shape[0], 1, 1])

@property
def data(self) -> torch.Tensor:
return self._data.data

0 comments on commit 0da955e

Please sign in to comment.