From c5d50aa33ff636c0293b0791f52eeb6f3835f4ff Mon Sep 17 00:00:00 2001 From: takuseno Date: Tue, 10 Oct 2023 00:18:37 +0900 Subject: [PATCH] Make position_encoding_type Enum --- d3rlpy/__init__.py | 20 +++++- .../algos/transformer/decision_transformer.py | 14 ++--- d3rlpy/constants.py | 6 ++ d3rlpy/logging/file_adapter.py | 3 + d3rlpy/models/builders.py | 61 +++++++++++-------- d3rlpy/models/torch/transformers.py | 1 + reproductions/offline/decision_transformer.py | 1 + .../offline/discrete_decision_transformer.py | 1 + tests/models/test_builders.py | 15 +++-- 9 files changed, 84 insertions(+), 38 deletions(-) diff --git a/d3rlpy/__init__.py b/d3rlpy/__init__.py index 8449db80..23868cd4 100644 --- a/d3rlpy/__init__.py +++ b/d3rlpy/__init__.py @@ -17,9 +17,27 @@ ) from ._version import __version__ from .base import load_learnable -from .constants import ActionSpace +from .constants import ActionSpace, PositionEncodingType from .healthcheck import run_healthcheck +__all__ = [ + "algos", + "dataset", + "datasets", + "envs", + "logging", + "metrics", + "models", + "notebook_utils", + "ope", + "preprocessing", + "__version__", + "load_learnable", + "ActionSpace", + "PositionEncodingType", + "seed", +] + def seed(n: int) -> None: """Sets random seed value. diff --git a/d3rlpy/algos/transformer/decision_transformer.py b/d3rlpy/algos/transformer/decision_transformer.py index 1e3a7586..c5988fe9 100644 --- a/d3rlpy/algos/transformer/decision_transformer.py +++ b/d3rlpy/algos/transformer/decision_transformer.py @@ -3,7 +3,7 @@ import torch from ...base import DeviceArg, register_learnable -from ...constants import ActionSpace +from ...constants import ActionSpace, PositionEncodingType from ...dataset import Shape from ...models import ( EncoderFactory, @@ -61,8 +61,8 @@ class DecisionTransformerConfig(TransformerConfig): resid_dropout (float): Dropout probability for residual connection. embed_dropout (float): Dropout probability for embeddings. activation_type (str): Type of activation function. - position_encoding_type (str): Type of positional encoding - (``simple`` or ``global``). + position_encoding_type (d3rlpy.PositionEncodingType): + Type of positional encoding (``SIMPLE`` or ``GLOBAL``). warmup_steps (int): Warmup steps for learning rate scheduler. clip_grad_norm (float): Norm of gradient clipping. compile (bool): (experimental) Flag to enable JIT compilation. @@ -78,7 +78,7 @@ class DecisionTransformerConfig(TransformerConfig): resid_dropout: float = 0.1 embed_dropout: float = 0.1 activation_type: str = "relu" - position_encoding_type: str = "simple" + position_encoding_type: PositionEncodingType = PositionEncodingType.SIMPLE warmup_steps: int = 10000 clip_grad_norm: float = 0.25 compile: bool = False @@ -172,8 +172,8 @@ class DiscreteDecisionTransformerConfig(TransformerConfig): activation_type (str): Type of activation function. embed_activation_type (str): Type of activation function applied to embeddings. - position_encoding_type (str): Type of positional encoding - (``simple`` or ``global``). + position_encoding_type (d3rlpy.PositionEncodingType): + Type of positional encoding (``SIMPLE`` or ``GLOBAL``). warmup_tokens (int): Number of tokens to warmup learning rate scheduler. final_tokens (int): Final number of tokens for learning rate scheduler. clip_grad_norm (float): Norm of gradient clipping. @@ -191,7 +191,7 @@ class DiscreteDecisionTransformerConfig(TransformerConfig): embed_dropout: float = 0.1 activation_type: str = "gelu" embed_activation_type: str = "tanh" - position_encoding_type: str = "global" + position_encoding_type: PositionEncodingType = PositionEncodingType.GLOBAL warmup_tokens: int = 10240 final_tokens: int = 30000000 clip_grad_norm: float = 1.0 diff --git a/d3rlpy/constants.py b/d3rlpy/constants.py index 43f15644..9b233760 100644 --- a/d3rlpy/constants.py +++ b/d3rlpy/constants.py @@ -8,6 +8,7 @@ "DISCRETE_ACTION_SPACE_MISMATCH_ERROR", "CONTINUOUS_ACTION_SPACE_MISMATCH_ERROR", "ActionSpace", + "PositionEncodingType", ] IMPL_NOT_INITIALIZED_ERROR = ( @@ -41,3 +42,8 @@ class ActionSpace(Enum): CONTINUOUS = 1 DISCRETE = 2 BOTH = 3 + + +class PositionEncodingType(Enum): + SIMPLE = "simple" + GLOBAL = "global" diff --git a/d3rlpy/logging/file_adapter.py b/d3rlpy/logging/file_adapter.py index f5cbc338..8b9f34e0 100644 --- a/d3rlpy/logging/file_adapter.py +++ b/d3rlpy/logging/file_adapter.py @@ -1,5 +1,6 @@ import json import os +from enum import Enum, IntEnum from typing import Any, Dict import numpy as np @@ -17,6 +18,8 @@ def default_json_encoder(obj: Any) -> Any: return float(obj) elif isinstance(obj, np.ndarray): return obj.tolist() + elif isinstance(obj, (Enum, IntEnum)): + return obj.value raise ValueError(f"invalid object type: {type(obj)}") diff --git a/d3rlpy/models/builders.py b/d3rlpy/models/builders.py index 48b183b2..266706bd 100644 --- a/d3rlpy/models/builders.py +++ b/d3rlpy/models/builders.py @@ -3,6 +3,7 @@ import torch from torch import nn +from ..constants import PositionEncodingType from ..dataset import Shape from .encoders import EncoderFactory from .q_functions import QFunctionFactory @@ -18,6 +19,7 @@ GlobalPositionEncoding, NormalPolicy, Parameter, + PositionEncoding, SimplePositionEncoding, VAEDecoder, VAEEncoder, @@ -247,6 +249,25 @@ def create_parameter( return parameter +def _create_position_encoding( + position_encoding_type: PositionEncodingType, + embed_dim: int, + max_timestep: int, + context_size: int, +) -> PositionEncoding: + if position_encoding_type == PositionEncodingType.SIMPLE: + position_encoding = SimplePositionEncoding(embed_dim, max_timestep + 1) + elif position_encoding_type == PositionEncodingType.GLOBAL: + position_encoding = GlobalPositionEncoding( + embed_dim, max_timestep + 1, context_size + ) + else: + raise ValueError( + f"invalid position_encoding_type: {position_encoding_type}" + ) + return position_encoding + + def create_continuous_decision_transformer( observation_shape: Shape, action_size: int, @@ -259,24 +280,18 @@ def create_continuous_decision_transformer( resid_dropout: float, embed_dropout: float, activation_type: str, - position_encoding_type: str, + position_encoding_type: PositionEncodingType, device: str, ) -> ContinuousDecisionTransformer: encoder = encoder_factory.create(observation_shape) hidden_size = compute_output_size([observation_shape], encoder) - if position_encoding_type == "simple": - position_encoding = SimplePositionEncoding( - hidden_size, max_timestep + 1 - ) - elif position_encoding_type == "global": - position_encoding = GlobalPositionEncoding( - hidden_size, max_timestep + 1, context_size - ) - else: - raise ValueError( - f"invalid position_encoding_type: {position_encoding_type}" - ) + position_encoding = _create_position_encoding( + position_encoding_type=position_encoding_type, + embed_dim=hidden_size, + max_timestep=max_timestep + 1, + context_size=context_size, + ) transformer = ContinuousDecisionTransformer( encoder=encoder, @@ -308,24 +323,18 @@ def create_discrete_decision_transformer( embed_dropout: float, activation_type: str, embed_activation_type: str, - position_encoding_type: str, + position_encoding_type: PositionEncodingType, device: str, ) -> DiscreteDecisionTransformer: encoder = encoder_factory.create(observation_shape) hidden_size = compute_output_size([observation_shape], encoder) - if position_encoding_type == "simple": - position_encoding = SimplePositionEncoding( - hidden_size, max_timestep + 1 - ) - elif position_encoding_type == "global": - position_encoding = GlobalPositionEncoding( - hidden_size, max_timestep + 1, context_size - ) - else: - raise ValueError( - f"invalid position_encoding_type: {position_encoding_type}" - ) + position_encoding = _create_position_encoding( + position_encoding_type=position_encoding_type, + embed_dim=hidden_size, + max_timestep=max_timestep + 1, + context_size=context_size, + ) transformer = DiscreteDecisionTransformer( encoder=encoder, diff --git a/d3rlpy/models/torch/transformers.py b/d3rlpy/models/torch/transformers.py index 7e28d89e..c010a6fb 100644 --- a/d3rlpy/models/torch/transformers.py +++ b/d3rlpy/models/torch/transformers.py @@ -12,6 +12,7 @@ __all__ = [ "ContinuousDecisionTransformer", "DiscreteDecisionTransformer", + "PositionEncoding", "SimplePositionEncoding", "GlobalPositionEncoding", ] diff --git a/reproductions/offline/decision_transformer.py b/reproductions/offline/decision_transformer.py index 5d96f548..9452dd7e 100644 --- a/reproductions/offline/decision_transformer.py +++ b/reproductions/offline/decision_transformer.py @@ -35,6 +35,7 @@ def main() -> None: ), observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(), reward_scaler=d3rlpy.preprocessing.MultiplyRewardScaler(0.001), + position_encoding_type=d3rlpy.PositionEncodingType.SIMPLE, context_size=20, num_heads=1, num_layers=3, diff --git a/reproductions/offline/discrete_decision_transformer.py b/reproductions/offline/discrete_decision_transformer.py index 15175fcb..8e015af8 100644 --- a/reproductions/offline/discrete_decision_transformer.py +++ b/reproductions/offline/discrete_decision_transformer.py @@ -69,6 +69,7 @@ def main() -> None: final_tokens=2 * 500000 * context_size * 3, observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), max_timestep=max_timestep, + position_encoding_type=d3rlpy.PositionEncodingType.GLOBAL, ).create(device=args.gpu) n_steps_per_epoch = dataset.transition_count // batch_size diff --git a/tests/models/test_builders.py b/tests/models/test_builders.py index 97923900..31f76e1e 100644 --- a/tests/models/test_builders.py +++ b/tests/models/test_builders.py @@ -4,6 +4,7 @@ import pytest import torch +from d3rlpy.constants import PositionEncodingType from d3rlpy.models.builders import ( create_categorical_policy, create_conditional_vae, @@ -270,7 +271,10 @@ def test_create_parameter(shape: Sequence[int]) -> None: @pytest.mark.parametrize("context_size", [10]) @pytest.mark.parametrize("dropout", [0.1]) @pytest.mark.parametrize("activation_type", ["relu"]) -@pytest.mark.parametrize("position_encoding_type", ["simple"]) +@pytest.mark.parametrize( + "position_encoding_type", + [PositionEncodingType.SIMPLE, PositionEncodingType.GLOBAL], +) @pytest.mark.parametrize("batch_size", [32]) def test_create_continuous_decision_transformer( observation_shape: Sequence[int], @@ -282,7 +286,7 @@ def test_create_continuous_decision_transformer( context_size: int, dropout: float, activation_type: str, - position_encoding_type: str, + position_encoding_type: PositionEncodingType, batch_size: int, ) -> None: transformer = create_continuous_decision_transformer( @@ -321,7 +325,10 @@ def test_create_continuous_decision_transformer( @pytest.mark.parametrize("context_size", [10]) @pytest.mark.parametrize("dropout", [0.1]) @pytest.mark.parametrize("activation_type", ["relu"]) -@pytest.mark.parametrize("position_encoding_type", ["simple"]) +@pytest.mark.parametrize( + "position_encoding_type", + [PositionEncodingType.SIMPLE, PositionEncodingType.GLOBAL], +) @pytest.mark.parametrize("batch_size", [32]) def test_create_discrete_decision_transformer( observation_shape: Sequence[int], @@ -333,7 +340,7 @@ def test_create_discrete_decision_transformer( context_size: int, dropout: float, activation_type: str, - position_encoding_type: str, + position_encoding_type: PositionEncodingType, batch_size: int, ) -> None: transformer = create_discrete_decision_transformer(