Commit: Make position_encoding_type Enum
takuseno committed Oct 9, 2023 · 1 parent f53fd3b · commit c5d50aa
Showing 9 changed files with 84 additions and 38 deletions.
20 changes: 19 additions & 1 deletion d3rlpy/__init__.py
@@ -17,9 +17,27 @@
)
from ._version import __version__
from .base import load_learnable
-from .constants import ActionSpace
+from .constants import ActionSpace, PositionEncodingType
from .healthcheck import run_healthcheck

+__all__ = [
+    "algos",
+    "dataset",
+    "datasets",
+    "envs",
+    "logging",
+    "metrics",
+    "models",
+    "notebook_utils",
+    "ope",
+    "preprocessing",
+    "__version__",
+    "load_learnable",
+    "ActionSpace",
+    "PositionEncodingType",
+    "seed",
+]


def seed(n: int) -> None:
"""Sets random seed value.
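With the new __all__ entry, PositionEncodingType is re-exported from the package root alongside ActionSpace. A minimal sketch of what this enables (not part of the commit itself):

import d3rlpy

# The top-level alias and the defining module expose the same enum object,
# so either import path can be used interchangeably.
assert d3rlpy.PositionEncodingType is d3rlpy.constants.PositionEncodingType
assert d3rlpy.PositionEncodingType.SIMPLE.value == "simple"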
14 changes: 7 additions & 7 deletions d3rlpy/algos/transformer/decision_transformer.py
@@ -3,7 +3,7 @@
import torch

from ...base import DeviceArg, register_learnable
-from ...constants import ActionSpace
+from ...constants import ActionSpace, PositionEncodingType
from ...dataset import Shape
from ...models import (
EncoderFactory,
@@ -61,8 +61,8 @@ class DecisionTransformerConfig(TransformerConfig):
resid_dropout (float): Dropout probability for residual connection.
embed_dropout (float): Dropout probability for embeddings.
activation_type (str): Type of activation function.
-    position_encoding_type (str): Type of positional encoding
-        (``simple`` or ``global``).
+    position_encoding_type (d3rlpy.PositionEncodingType):
+        Type of positional encoding (``SIMPLE`` or ``GLOBAL``).
warmup_steps (int): Warmup steps for learning rate scheduler.
clip_grad_norm (float): Norm of gradient clipping.
compile (bool): (experimental) Flag to enable JIT compilation.
@@ -78,7 +78,7 @@ class DecisionTransformerConfig(TransformerConfig):
resid_dropout: float = 0.1
embed_dropout: float = 0.1
activation_type: str = "relu"
-    position_encoding_type: str = "simple"
+    position_encoding_type: PositionEncodingType = PositionEncodingType.SIMPLE
warmup_steps: int = 10000
clip_grad_norm: float = 0.25
compile: bool = False
@@ -172,8 +172,8 @@ class DiscreteDecisionTransformerConfig(TransformerConfig):
activation_type (str): Type of activation function.
embed_activation_type (str): Type of activation function applied to
embeddings.
-    position_encoding_type (str): Type of positional encoding
-        (``simple`` or ``global``).
+    position_encoding_type (d3rlpy.PositionEncodingType):
+        Type of positional encoding (``SIMPLE`` or ``GLOBAL``).
warmup_tokens (int): Number of tokens to warmup learning rate scheduler.
final_tokens (int): Final number of tokens for learning rate scheduler.
clip_grad_norm (float): Norm of gradient clipping.
@@ -191,7 +191,7 @@ class DiscreteDecisionTransformerConfig(TransformerConfig):
embed_dropout: float = 0.1
activation_type: str = "gelu"
embed_activation_type: str = "tanh"
-    position_encoding_type: str = "global"
+    position_encoding_type: PositionEncodingType = PositionEncodingType.GLOBAL
warmup_tokens: int = 10240
final_tokens: int = 30000000
clip_grad_norm: float = 1.0
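As a usage sketch (the argument value below is illustrative, not from this commit), these configs now take an enum member where a bare string went before:

import d3rlpy

# Override the config's default of PositionEncodingType.SIMPLE with the
# global encoding; a bare string like "simple" is no longer the documented type.
config = d3rlpy.algos.DecisionTransformerConfig(
    position_encoding_type=d3rlpy.PositionEncodingType.GLOBAL,
)
dt = config.create(device=False)  # False selects CPU in d3rlpy's DeviceArg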
6 changes: 6 additions & 0 deletions d3rlpy/constants.py
@@ -8,6 +8,7 @@
"DISCRETE_ACTION_SPACE_MISMATCH_ERROR",
"CONTINUOUS_ACTION_SPACE_MISMATCH_ERROR",
"ActionSpace",
"PositionEncodingType",
]

IMPL_NOT_INITIALIZED_ERROR = (
@@ -41,3 +42,8 @@ class ActionSpace(Enum):
    CONTINUOUS = 1
    DISCRETE = 2
    BOTH = 3
+
+
+class PositionEncodingType(Enum):
+    SIMPLE = "simple"
+    GLOBAL = "global"
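Because the members wrap the old string literals as their values, legacy string configs map onto the enum by value. A quick round-trip sketch (not part of the commit):

from d3rlpy.constants import PositionEncodingType

# Lookup by value recovers the member from the legacy string form.
assert PositionEncodingType("simple") is PositionEncodingType.SIMPLE
assert PositionEncodingType.GLOBAL.value == "global"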
3 changes: 3 additions & 0 deletions d3rlpy/logging/file_adapter.py
@@ -1,5 +1,6 @@
import json
import os
+from enum import Enum, IntEnum
from typing import Any, Dict

import numpy as np
@@ -17,6 +18,8 @@ def default_json_encoder(obj: Any) -> Any:
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
+    elif isinstance(obj, (Enum, IntEnum)):
+        return obj.value
    raise ValueError(f"invalid object type: {type(obj)}")


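This branch keeps logged hyperparameters JSON-serializable now that they can contain enum members. A sketch of the resulting behavior (the params dict here is hypothetical):

import json

from d3rlpy.constants import PositionEncodingType
from d3rlpy.logging.file_adapter import default_json_encoder

params = {"position_encoding_type": PositionEncodingType.SIMPLE}
# json.dumps invokes default_json_encoder for non-JSON-native objects,
# so the enum member is written as its underlying string value.
print(json.dumps(params, default=default_json_encoder))
# -> {"position_encoding_type": "simple"}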
61 changes: 35 additions & 26 deletions d3rlpy/models/builders.py
@@ -3,6 +3,7 @@
import torch
from torch import nn

+from ..constants import PositionEncodingType
from ..dataset import Shape
from .encoders import EncoderFactory
from .q_functions import QFunctionFactory
@@ -18,6 +19,7 @@
    GlobalPositionEncoding,
    NormalPolicy,
    Parameter,
+    PositionEncoding,
    SimplePositionEncoding,
    VAEDecoder,
    VAEEncoder,
@@ -247,6 +249,25 @@ def create_parameter(
return parameter


+def _create_position_encoding(
+    position_encoding_type: PositionEncodingType,
+    embed_dim: int,
+    max_timestep: int,
+    context_size: int,
+) -> PositionEncoding:
+    if position_encoding_type == PositionEncodingType.SIMPLE:
+        position_encoding = SimplePositionEncoding(embed_dim, max_timestep)
+    elif position_encoding_type == PositionEncodingType.GLOBAL:
+        position_encoding = GlobalPositionEncoding(
+            embed_dim, max_timestep, context_size
+        )
+    else:
+        raise ValueError(
+            f"invalid position_encoding_type: {position_encoding_type}"
+        )
+    return position_encoding


def create_continuous_decision_transformer(
observation_shape: Shape,
action_size: int,
Expand All @@ -259,24 +280,18 @@ def create_continuous_decision_transformer(
resid_dropout: float,
embed_dropout: float,
activation_type: str,
-    position_encoding_type: str,
+    position_encoding_type: PositionEncodingType,
device: str,
) -> ContinuousDecisionTransformer:
encoder = encoder_factory.create(observation_shape)
hidden_size = compute_output_size([observation_shape], encoder)

if position_encoding_type == "simple":
position_encoding = SimplePositionEncoding(
hidden_size, max_timestep + 1
)
elif position_encoding_type == "global":
position_encoding = GlobalPositionEncoding(
hidden_size, max_timestep + 1, context_size
)
else:
raise ValueError(
f"invalid position_encoding_type: {position_encoding_type}"
)
position_encoding = _create_position_encoding(
position_encoding_type=position_encoding_type,
embed_dim=hidden_size,
max_timestep=max_timestep + 1,
context_size=context_size,
)

transformer = ContinuousDecisionTransformer(
encoder=encoder,
@@ -308,24 +323,18 @@ def create_discrete_decision_transformer(
embed_dropout: float,
activation_type: str,
embed_activation_type: str,
-    position_encoding_type: str,
+    position_encoding_type: PositionEncodingType,
device: str,
) -> DiscreteDecisionTransformer:
encoder = encoder_factory.create(observation_shape)
hidden_size = compute_output_size([observation_shape], encoder)

if position_encoding_type == "simple":
position_encoding = SimplePositionEncoding(
hidden_size, max_timestep + 1
)
elif position_encoding_type == "global":
position_encoding = GlobalPositionEncoding(
hidden_size, max_timestep + 1, context_size
)
else:
raise ValueError(
f"invalid position_encoding_type: {position_encoding_type}"
)
position_encoding = _create_position_encoding(
position_encoding_type=position_encoding_type,
embed_dim=hidden_size,
max_timestep=max_timestep + 1,
context_size=context_size,
)

transformer = DiscreteDecisionTransformer(
encoder=encoder,
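A sketch of the new helper in isolation (the sizes are made up for illustration, and _create_position_encoding is a private function, so this is only to show the branching):

from d3rlpy.constants import PositionEncodingType
from d3rlpy.models.builders import _create_position_encoding

# 128-dim embeddings, trajectories up to 1000 steps, context window of 20.
encoding = _create_position_encoding(
    position_encoding_type=PositionEncodingType.GLOBAL,
    embed_dim=128,
    max_timestep=1001,  # mirrors the max_timestep + 1 at the call sites
    context_size=20,
)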
1 change: 1 addition & 0 deletions d3rlpy/models/torch/transformers.py
@@ -12,6 +12,7 @@
__all__ = [
    "ContinuousDecisionTransformer",
    "DiscreteDecisionTransformer",
+    "PositionEncoding",
    "SimplePositionEncoding",
    "GlobalPositionEncoding",
]
1 change: 1 addition & 0 deletions reproductions/offline/decision_transformer.py
@@ -35,6 +35,7 @@ def main() -> None:
),
observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(),
reward_scaler=d3rlpy.preprocessing.MultiplyRewardScaler(0.001),
+        position_encoding_type=d3rlpy.PositionEncodingType.SIMPLE,
context_size=20,
num_heads=1,
num_layers=3,
1 change: 1 addition & 0 deletions reproductions/offline/discrete_decision_transformer.py
@@ -69,6 +69,7 @@ def main() -> None:
final_tokens=2 * 500000 * context_size * 3,
observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(),
max_timestep=max_timestep,
+        position_encoding_type=d3rlpy.PositionEncodingType.GLOBAL,
).create(device=args.gpu)

n_steps_per_epoch = dataset.transition_count // batch_size
15 changes: 11 additions & 4 deletions tests/models/test_builders.py
@@ -4,6 +4,7 @@
import pytest
import torch

+from d3rlpy.constants import PositionEncodingType
from d3rlpy.models.builders import (
create_categorical_policy,
create_conditional_vae,
@@ -270,7 +271,10 @@ def test_create_parameter(shape: Sequence[int]) -> None:
@pytest.mark.parametrize("context_size", [10])
@pytest.mark.parametrize("dropout", [0.1])
@pytest.mark.parametrize("activation_type", ["relu"])
@pytest.mark.parametrize("position_encoding_type", ["simple"])
@pytest.mark.parametrize(
"position_encoding_type",
[PositionEncodingType.SIMPLE, PositionEncodingType.GLOBAL],
)
@pytest.mark.parametrize("batch_size", [32])
def test_create_continuous_decision_transformer(
observation_shape: Sequence[int],
@@ -282,7 +286,7 @@ def test_create_continuous_decision_transformer(
context_size: int,
dropout: float,
activation_type: str,
-    position_encoding_type: str,
+    position_encoding_type: PositionEncodingType,
batch_size: int,
) -> None:
transformer = create_continuous_decision_transformer(
@@ -321,7 +325,10 @@ def test_create_continuous_decision_transformer(
@pytest.mark.parametrize("context_size", [10])
@pytest.mark.parametrize("dropout", [0.1])
@pytest.mark.parametrize("activation_type", ["relu"])
@pytest.mark.parametrize("position_encoding_type", ["simple"])
@pytest.mark.parametrize(
"position_encoding_type",
[PositionEncodingType.SIMPLE, PositionEncodingType.GLOBAL],
)
@pytest.mark.parametrize("batch_size", [32])
def test_create_discrete_decision_transformer(
observation_shape: Sequence[int],
@@ -333,7 +340,7 @@ def test_create_discrete_decision_transformer(
context_size: int,
dropout: float,
activation_type: str,
-    position_encoding_type: str,
+    position_encoding_type: PositionEncodingType,
batch_size: int,
) -> None:
transformer = create_discrete_decision_transformer(
