Add PrecisionScaler
takuseno committed Jan 24, 2024
1 parent 0da955e commit 070b070
Showing 3 changed files with 36 additions and 12 deletions.
2 changes: 2 additions & 0 deletions d3rlpy/__init__.py
@@ -11,6 +11,7 @@
     envs,
     logging,
     metrics,
+    mixed_precision,
     models,
     notebook_utils,
     ope,
@@ -35,6 +36,7 @@
     "ope",
     "preprocessing",
     "tokenizers",
+    "mixed_precision",
     "__version__",
     "load_learnable",
     "ActionSpace",
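With the re-export above, the new module is reachable from the package root. The two names used later in this commit can be imported directly; whether the module exposes anything else is not visible from this diff. A minimal sketch:

from d3rlpy.mixed_precision import NoCastPrecisionScaler, PrecisionScaler

scaler: PrecisionScaler = NoCastPrecisionScaler()  # full-precision no-op default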
33 changes: 25 additions & 8 deletions d3rlpy/algos/gato/base.py
@@ -28,8 +28,13 @@
     LoggerAdapterFactory,
 )
 from ...metrics import evaluate_gato_with_environment
+from ...mixed_precision import NoCastPrecisionScaler, PrecisionScaler
 from ...models import EmbeddingModuleFactory, TokenEmbeddingFactory
-from ...models.torch import SeparatorTokenEmbedding, TokenEmbedding, get_parameter
+from ...models.torch import (
+    SeparatorTokenEmbedding,
+    TokenEmbedding,
+    get_parameter,
+)
 from ...serializable_config import generate_dict_config_field
 from ...torch_utility import eval_api, train_api
 from ...types import GymEnv, NDArray, Observation
@@ -62,13 +67,19 @@ def inner_predict(self, inpt: GatoInputEmbedding) -> int:
 
     @train_api
     def update(
-        self, batch: GatoEmbeddingMiniBatch, grad_step: int
+        self,
+        batch: GatoEmbeddingMiniBatch,
+        grad_step: int,
+        precision_scaler: PrecisionScaler,
     ) -> Dict[str, float]:
-        return self.inner_update(batch, grad_step)
+        return self.inner_update(batch, grad_step, precision_scaler)
 
     @abstractmethod
     def inner_update(
-        self, batch: GatoEmbeddingMiniBatch, grad_step: int
+        self,
+        batch: GatoEmbeddingMiniBatch,
+        grad_step: int,
+        precision_scaler: PrecisionScaler,
     ) -> Dict[str, float]:
         pass

@@ -236,7 +247,9 @@ def _append_action_embedding(self, embedding: torch.Tensor) -> None:
 
     def _append_separator_embedding(self) -> None:
         assert self._algo.impl
-        self._embeddings.append(get_parameter(self._algo.impl.separator_token_embedding))
+        self._embeddings.append(
+            get_parameter(self._algo.impl.separator_token_embedding)
+        )
         self._observation_positions.append(0)
         self._observation_masks.append(0)
         self._action_masks.append(0)
@@ -341,6 +354,7 @@ def fit(
         evaluators: Optional[Dict[str, GatoEnvironmentEvaluator]] = None,
         callback: Optional[Callable[[Self, int, int], None]] = None,
         enable_ddp: bool = False,
+        precision_scaler: PrecisionScaler = NoCastPrecisionScaler(),
     ) -> None:
         """Trains with given dataset.
@@ -359,6 +373,7 @@
             callback: Callable function that takes ``(algo, epoch, total_step)``
                 , which is called every step.
             enable_ddp: Flag to wrap models with DataDistributedParallel.
+            precision_scaler: Precision scaler for mixed precision training.
         """
 
         # setup logger
@@ -419,7 +434,7 @@
 
                 # update parameters
                 with logger.measure_time("algorithm_update"):
-                    loss = self.update(batch)
+                    loss = self.update(batch, precision_scaler)
 
                 # record metrics
                 for name, val in loss.items():
@@ -453,7 +468,9 @@
 
         logger.close()
 
-    def update(self, batch: GatoEmbeddingMiniBatch) -> Dict[str, float]:
+    def update(
+        self, batch: GatoEmbeddingMiniBatch, precision_scaler: PrecisionScaler
+    ) -> Dict[str, float]:
         """Update parameters with mini-batch of data.
 
         Args:
@@ -463,6 +480,6 @@ def update(self, batch: GatoEmbeddingMiniBatch) -> Dict[str, float]:
             Dictionary of metrics.
         """
         assert self._impl, IMPL_NOT_INITIALIZED_ERROR
-        loss = self._impl.update(batch, self._grad_step)
+        loss = self._impl.update(batch, self._grad_step, precision_scaler)
         self._grad_step += 1
         return loss
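For callers, the change to base.py is additive: fit() gains a precision_scaler keyword that defaults to NoCastPrecisionScaler(), so existing training scripts keep their full-precision behavior. A hedged usage sketch, where `gato` and `datasets` stand in for a configured Gato algorithm and its datasets built as before this commit, and the step-count keyword names are assumptions rather than taken from this diff:

from d3rlpy.mixed_precision import NoCastPrecisionScaler

# Only precision_scaler is new; everything else is a placeholder.
gato.fit(
    datasets,
    n_steps=100_000,
    n_steps_per_epoch=1_000,
    precision_scaler=NoCastPrecisionScaler(),  # default; pass a loss-scaling variant for fp16
)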
13 changes: 9 additions & 4 deletions d3rlpy/algos/gato/torch/gato_impl.py
@@ -8,6 +8,7 @@
 from torch import nn
 from torch.optim import Optimizer
 
+from ....mixed_precision import PrecisionScaler
 from ....models.torch import (
     GatoTransformer,
     SeparatorTokenEmbedding,
@@ -71,11 +72,15 @@ def inner_predict(self, inpt: GatoInputEmbedding) -> int:
         return int(np.argmax(logits[0][-1].cpu().detach().numpy()))
 
     def inner_update(
-        self, batch: GatoEmbeddingMiniBatch, grad_step: int
+        self,
+        batch: GatoEmbeddingMiniBatch,
+        grad_step: int,
+        precision_scaler: PrecisionScaler,
     ) -> Dict[str, float]:
         self._modules.optim.zero_grad()
-        loss = self.compute_loss(batch)
-        loss.backward()
+        with precision_scaler.autocast():
+            loss = self.compute_loss(batch)
+        precision_scaler.scale_and_backward(self._modules.optim, loss)
 
         torch.nn.utils.clip_grad_norm_(
             list(self._modules.transformer.parameters())
@@ -84,7 +89,7 @@ def inner_update(
             self._clip_grad_norm,
         )
 
-        self._modules.optim.step()
+        precision_scaler.step(self._modules.optim)
 
         # schedule learning rate
         # linear warmup
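The three scaler calls in inner_update pin down the expected interface: autocast() must return a context manager wrapping the loss computation, scale_and_backward() replaces the plain loss.backward(), and step() replaces optim.step(). The body of d3rlpy.mixed_precision is not part of this diff, so the following is only a sketch of what an implementation could look like; the AMPPrecisionScaler class and its name are illustrative, not part of the commit.

# Sketch only: interface inferred from the call sites in gato_impl.py above.
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import ContextManager

import torch
from torch.optim import Optimizer


class PrecisionScaler(ABC):
    @abstractmethod
    def autocast(self) -> ContextManager[None]:
        """Returns a context manager wrapping the forward/loss computation."""

    @abstractmethod
    def scale_and_backward(self, optim: Optimizer, loss: torch.Tensor) -> None:
        """Backpropagates, scaling the loss first if necessary."""

    @abstractmethod
    def step(self, optim: Optimizer) -> None:
        """Applies the optimizer update."""


class NoCastPrecisionScaler(PrecisionScaler):
    """Full-precision default: reproduces the pre-commit code path."""

    def autocast(self) -> ContextManager[None]:
        return nullcontext()

    def scale_and_backward(self, optim: Optimizer, loss: torch.Tensor) -> None:
        loss.backward()

    def step(self, optim: Optimizer) -> None:
        optim.step()


class AMPPrecisionScaler(PrecisionScaler):
    """Hypothetical fp16 variant built on torch.cuda.amp (not in this commit)."""

    def __init__(self) -> None:
        self._grad_scaler = torch.cuda.amp.GradScaler()

    def autocast(self) -> ContextManager[None]:
        return torch.cuda.amp.autocast()

    def scale_and_backward(self, optim: Optimizer, loss: torch.Tensor) -> None:
        # Scale the loss to avoid fp16 gradient underflow, then backprop.
        self._grad_scaler.scale(loss).backward()
        # Unscale right away so the caller's clip_grad_norm_, which runs
        # between scale_and_backward() and step() in inner_update, sees
        # true-magnitude gradients.
        self._grad_scaler.unscale_(optim)

    def step(self, optim: Optimizer) -> None:
        # GradScaler.step skips the update if inf/NaN gradients were found.
        self._grad_scaler.step(optim)
        self._grad_scaler.update()

Splitting backward and step into separate scaler hooks keeps the existing clip_grad_norm_ call in between untouched; a loss-scaling implementation only has to make sure gradients are unscaled before that clipping runs.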
