Add DroQ #342

Draft · wants to merge 2 commits into master
1 change: 1 addition & 0 deletions d3rlpy/algos/qlearning/__init__.py
@@ -7,6 +7,7 @@
from .crr import *
from .ddpg import *
from .dqn import *
from .droq import *
from .explorers import *
from .iql import *
from .nfq import *
29 changes: 16 additions & 13 deletions d3rlpy/algos/qlearning/base.py
@@ -574,6 +574,7 @@ def fit_online(
n_steps_per_epoch: int = 10000,
update_interval: int = 1,
update_start_step: int = 0,
utd_ratio: int = 1,
random_steps: int = 0,
eval_env: Optional[GymEnv] = None,
eval_epsilon: float = 0.0,
@@ -594,6 +595,7 @@
n_steps_per_epoch: Number of steps per epoch.
update_interval: Number of steps per update.
update_start_step: Steps before starting updates.
utd_ratio: UTD (update-to-data) ratio, the number of gradient updates performed per environment interaction.
random_steps: Steps for the initial random exploration.
eval_env: Gym-like environment. If None, evaluation is skipped.
eval_epsilon: :math:`\\epsilon`-greedy factor during evaluation.
@@ -691,19 +693,20 @@ def fit_online(
and buffer.transition_count > self.batch_size
):
if total_step % update_interval == 0:
# sample mini-batch
with logger.measure_time("sample_batch"):
batch = buffer.sample_transition_batch(
self.batch_size
)

# update parameters
with logger.measure_time("algorithm_update"):
loss = self.update(batch)

# record metrics
for name, val in loss.items():
logger.add_metric(name, val)
for _ in range(utd_ratio):
# sample mini-batch
with logger.measure_time("sample_batch"):
batch = buffer.sample_transition_batch(
self.batch_size
)

# update parameters
with logger.measure_time("algorithm_update"):
loss = self.update(batch)

# record metrics
for name, val in loss.items():
logger.add_metric(name, val)

# call callback if given
if callback:
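For reference, a minimal sketch of how the new `utd_ratio` argument could be used once this lands. The environment, replay-buffer helper, and step counts below are the usual d3rlpy online-training ingredients and are assumptions rather than part of this diff; the only DroQ-specific piece is passing a large `utd_ratio`.

```python
import gym
import d3rlpy

# Assumed setup (not in this PR): a continuous-control env and the standard
# FIFO replay buffer helper.
env = gym.make("Hopper-v2")
buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=1000000, env=env)

droq = d3rlpy.algos.DroQConfig().create(device="cuda:0")

# utd_ratio=20 runs 20 gradient updates per environment step instead of the
# default single update, matching the high-UTD regime DroQ is designed for.
droq.fit_online(
    env,
    buffer,
    n_steps=100000,
    n_steps_per_epoch=1000,
    utd_ratio=20,
)
```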
121 changes: 121 additions & 0 deletions d3rlpy/algos/qlearning/droq.py
@@ -0,0 +1,121 @@
import dataclasses
import math

from .torch import SACModules
from .torch.droq_impl import DroQImpl
from ...base import DeviceArg, register_learnable, LearnableConfig
from ...constants import ActionSpace
from ...dataset import Shape
from ...models import QFunctionFactory, make_q_func_field, make_optimizer_field, OptimizerFactory
from ...models.builders import (
create_continuous_q_function,
create_normal_policy,
create_parameter,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from .base import QLearningAlgoBase


__all__ = ["DroQConfig", "DroQ"]


@dataclasses.dataclass()
class DroQConfig(LearnableConfig):
r"""TODO
"""
actor_learning_rate: float = 3e-4
critic_learning_rate: float = 3e-4
temp_learning_rate: float = 3e-4
actor_optim_factory: OptimizerFactory = make_optimizer_field()
critic_optim_factory: OptimizerFactory = make_optimizer_field()
temp_optim_factory: OptimizerFactory = make_optimizer_field()
actor_encoder_factory: EncoderFactory = make_encoder_field()
critic_encoder_factory: EncoderFactory = make_encoder_field()
q_func_factory: QFunctionFactory = make_q_func_field()
batch_size: int = 256
gamma: float = 0.99
tau: float = 0.005
n_critics: int = 2
initial_temperature: float = 1.0

def create(self, device: DeviceArg = False) -> "DroQ":
return DroQ(self, device)

@staticmethod
def get_type() -> str:
return "droq"


class DroQ(QLearningAlgoBase[DroQImpl, DroQConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
policy = create_normal_policy(
observation_shape,
action_size,
self._config.actor_encoder_factory,
device=self._device,
)
q_funcs, q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
)
targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
)
log_temp = create_parameter(
(1, 1),
math.log(self._config.initial_temperature),
device=self._device,
)

actor_optim = self._config.actor_optim_factory.create(
policy.parameters(), lr=self._config.actor_learning_rate
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.parameters(), lr=self._config.critic_learning_rate
)
if self._config.temp_learning_rate > 0:
temp_optim = self._config.temp_optim_factory.create(
log_temp.parameters(), lr=self._config.temp_learning_rate
)
else:
temp_optim = None

modules = SACModules(
policy=policy,
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
log_temp=log_temp,
actor_optim=actor_optim,
critic_optim=critic_optim,
temp_optim=temp_optim,
)

self._impl = DroQImpl(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
tau=self._config.tau,
device=self._device,
)

def get_action_type(self) -> ActionSpace:
return ActionSpace.CONTINUOUS


# (TODO IF VALID) class DiscreteDroQConfig(LearnableConfig):

register_learnable(DroQConfig)
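Since `DroQConfig` currently mirrors the SAC config fields, the DroQ-specific pieces from the paper (dropout and layer normalization inside the Q-networks) would presumably be enabled through the critic encoder factory, using the `use_layer_norm` flag added to `VectorEncoderFactory` further down in this diff. A sketch under that assumption; the hidden sizes and dropout rate are illustrative values, not defaults set by this PR:

```python
import d3rlpy
from d3rlpy.models.encoders import VectorEncoderFactory

# DroQ keeps a small critic ensemble (n_critics=2) and regularizes each
# Q-network with dropout + layer normalization instead of growing the ensemble.
critic_encoder = VectorEncoderFactory(
    hidden_units=[256, 256],
    dropout_rate=0.01,
    use_layer_norm=True,
)

droq = d3rlpy.algos.DroQConfig(
    critic_encoder_factory=critic_encoder,
    n_critics=2,
).create(device="cuda:0")
```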
24 changes: 24 additions & 0 deletions d3rlpy/algos/qlearning/torch/droq_impl.py
@@ -0,0 +1,24 @@
import torch

from . import SACImpl
from ....models.torch import build_squashed_gaussian_distribution
from ....torch_utility import TorchMiniBatch

__all__ = ["DroQImpl"]


class DroQImpl(SACImpl):
def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
dist = build_squashed_gaussian_distribution(
self._modules.policy(batch.observations)
)
action, log_prob = dist.sample_with_log_prob()
entropy = self._modules.log_temp().exp() * log_prob
q_t = self._q_func_forwarder.compute_expected_q(
# Use "mean" (line 10 of Algorithm 2 in the paper)
batch.observations, action, "mean"
)
return (entropy - q_t).mean()


# (TODO IF VALID) class DiscreteDroQImpl
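The only override relative to `SACImpl` is the actor loss, which averages the ensemble Q-values (`"mean"`) rather than taking the conservative minimum, following line 10 of Algorithm 2 in the DroQ paper. Roughly, in the paper's notation (temperature α, ensemble of M dropout Q-functions), the quantity being minimized is:

```math
L_{\pi}(\phi) = \mathbb{E}_{s \sim \mathcal{B},\, a \sim \pi_\phi(\cdot \mid s)}
\Big[\, \alpha \log \pi_\phi(a \mid s) - \frac{1}{M} \sum_{i=1}^{M} Q_{\theta_i}(s, a) \,\Big]
```

which is what `(entropy - q_t).mean()` computes above, with `M = n_critics` and α read from `log_temp`.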
4 changes: 4 additions & 0 deletions d3rlpy/models/encoders.py
@@ -138,6 +138,7 @@ class VectorEncoderFactory(EncoderFactory):
activation: str = "relu"
use_batch_norm: bool = False
dropout_rate: Optional[float] = None
use_layer_norm: bool = False
exclude_last_activation: bool = False

def create(self, observation_shape: Shape) -> VectorEncoder:
@@ -147,6 +148,7 @@ def create(self, observation_shape: Shape) -> VectorEncoder:
hidden_units=self.hidden_units,
use_batch_norm=self.use_batch_norm,
dropout_rate=self.dropout_rate,
# use_layer_norm=self.use_layer_norm,
activation=create_activation(self.activation),
exclude_last_activation=self.exclude_last_activation,
)
@@ -164,6 +166,7 @@ def create_with_action(
hidden_units=self.hidden_units,
use_batch_norm=self.use_batch_norm,
dropout_rate=self.dropout_rate,
use_layer_norm=self.use_layer_norm,
discrete_action=discrete_action,
activation=create_activation(self.activation),
exclude_last_activation=self.exclude_last_activation,
@@ -189,6 +192,7 @@ class DefaultEncoderFactory(EncoderFactory):
activation: str = "relu"
use_batch_norm: bool = False
dropout_rate: Optional[float] = None
use_layer_norm: bool = False

def create(self, observation_shape: Shape) -> Encoder:
factory: Union[PixelEncoderFactory, VectorEncoderFactory]
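A quick way to sanity-check the new flag at the torch level (a sketch, not one of this PR's tests; the `create_with_action` signature and shapes are assumed from the existing factory API): with `use_layer_norm=True` the built encoder should contain `nn.LayerNorm` modules and still emit features of size `hidden_units[-1]`.

```python
import torch
from d3rlpy.models.encoders import VectorEncoderFactory

factory = VectorEncoderFactory(hidden_units=[256, 256], use_layer_norm=True)

# The action-conditioned path is the one that forwards use_layer_norm in this
# diff (the observation-only create() still has it commented out).
encoder = factory.create_with_action(observation_shape=(17,), action_size=6)

print(any(isinstance(m, torch.nn.LayerNorm) for m in encoder.modules()))  # True
print(encoder(torch.rand(32, 17), torch.rand(32, 6)).shape)  # torch.Size([32, 256])
```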
3 changes: 3 additions & 0 deletions d3rlpy/models/torch/encoders.py
@@ -226,6 +226,7 @@
hidden_units: Optional[Sequence[int]] = None,
use_batch_norm: bool = False,
dropout_rate: Optional[float] = None,
use_layer_norm: bool = False,
discrete_action: bool = False,
activation: nn.Module = nn.ReLU(),
exclude_last_activation: bool = False,
@@ -251,6 +252,8 @@
layers.append(nn.BatchNorm1d(out_unit))
if dropout_rate is not None:
layers.append(nn.Dropout(dropout_rate))
if use_layer_norm:
layers.append(nn.LayerNorm(out_unit))

self._layers = nn.Sequential(*layers)

def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor: