
Commit

update code
Mamba413 committed Jul 22, 2024
1 parent 2520211 commit f4cd433
Showing 4 changed files with 132 additions and 23 deletions.
86 changes: 80 additions & 6 deletions d3rlpy/algos/qlearning/iql.py
@@ -3,19 +3,27 @@
from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...models.builders import (
create_categorical_policy,
create_continuous_q_function,
create_discrete_q_function,
create_categorical_policy,
create_normal_policy,
create_value_function,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import MeanQFunctionFactory
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...models.q_functions import (
MeanQFunctionFactory,
QFunctionFactory,
make_q_func_field,
)
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.iql_impl import IQLImpl, IQLModules, DiscreteIQLImpl, DiscreteIQLModules
from .torch.iql_impl import (
DiscreteIQLImpl,
DiscreteIQLModules,
IQLImpl,
IQLModules,
)

__all__ = ["IQLConfig", "IQL", "DiscreteIQLConfig", "DiscreteIQL"]

@@ -178,19 +186,83 @@ def inner_create_impl(
def get_action_type(self) -> ActionSpace:
return ActionSpace.CONTINUOUS


@dataclasses.dataclass()
class DiscreteIQLConfig(LearnableConfig):
r"""Implicit Q-Learning algorithm.
IQL is the offline RL algorithm that avoids ever querying values of unseen
actions while still being able to perform multi-step dynamic programming
updates.
There are three functions to train in IQL. First the state-value function
is trained via expectile regression.
.. math::
L_V(\psi) = \mathbb{E}_{(s, a) \sim D}
[L_2^\tau (Q_\theta (s, a) - V_\psi (s))]
where :math:`L_2^\tau (u) = |\tau - \mathbb{1}(u < 0)|u^2`.
The Q-function is trained with the state-value function to avoid query the
actions.
.. math::
L_Q(\theta) = \mathbb{E}_{(s, a, r, s') \sim D}
[(r + \gamma V_\psi(s') - Q_\theta(s, a))^2]
Finally, the policy function is trained by using advantage weighted
regression compared with `IQL`, here we use a categorical policy.
.. math::
L_\pi (\phi) = \mathbb{E}_{(s, a) \sim D}
[\exp(\beta (Q_\theta - V_\psi(s))) \log \pi_\phi(a|s)]
References:
* `Kostrikov et al., Offline Reinforcement Learning with Implicit
Q-Learning. <https://arxiv.org/abs/2110.06169>`_
Args:
observation_scaler (d3rlpy.preprocessing.ObservationScaler):
Observation preprocessor.
action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor.
actor_learning_rate (float): Learning rate for policy function.
critic_learning_rate (float): Learning rate for Q functions.
actor_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
Optimizer factory for the actor.
critic_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
Optimizer factory for the critic.
actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
Encoder factory for the actor.
critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
Encoder factory for the critic.
value_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
Encoder factory for the value function.
batch_size (int): Mini-batch size.
gamma (float): Discount factor.
tau (float): Target network synchronization coefficiency.
n_critics (int): Number of Q functions for ensemble.
expectile (float): Expectile value for value function training.
weight_temp (float): Inverse temperature value represented as
:math:`\beta`.
max_weight (float): Maximum advantage weight value to clip.
"""

actor_learning_rate: float = 3e-4
critic_learning_rate: float = 3e-4

q_func_factory: QFunctionFactory = make_q_func_field()
encoder_factory: EncoderFactory = make_encoder_field()
value_encoder_factory: EncoderFactory = make_encoder_field()
critic_optim_factory: OptimizerFactory = make_optimizer_field()

actor_encoder_factory: EncoderFactory = make_encoder_field()
actor_optim_factory: OptimizerFactory = make_optimizer_field()

batch_size: int = 256
gamma: float = 0.99
tau: float = 0.005
@@ -206,6 +278,7 @@ def create(self, device: DeviceArg = False) -> "DiscreteIQL":
def get_type() -> str:
return "discrete_iql"


class DiscreteIQL(QLearningAlgoBase[DiscreteIQLImpl, DiscreteIQLConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
@@ -273,5 +346,6 @@ def inner_create_impl(
def get_action_type(self) -> ActionSpace:
return ActionSpace.DISCRETE


register_learnable(IQLConfig)
register_learnable(DiscreteIQLConfig)
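
For readers skimming the diff, here is a minimal, self-contained PyTorch sketch of the three losses described in the DiscreteIQLConfig docstring above (expectile value regression, Q-function regression, and the advantage-weighted policy loss). The function names and tensor shapes are illustrative assumptions, not part of this commit; the actual training code lives in DiscreteIQLImpl in d3rlpy/algos/qlearning/torch/iql_impl.py below.

# Illustrative sketch only -- not the d3rlpy implementation.
import torch
import torch.nn.functional as F

def expectile_value_loss(
    q_t: torch.Tensor, v_t: torch.Tensor, expectile: float
) -> torch.Tensor:
    # L_V(psi) = E[|tau - 1(u < 0)| * u^2] with u = Q(s, a) - V(s)
    diff = q_t.detach() - v_t
    weight = (expectile - (diff < 0.0).float()).abs()
    return (weight * diff.pow(2)).mean()

def q_function_loss(
    q_sa: torch.Tensor, reward: torch.Tensor, v_next: torch.Tensor, gamma: float
) -> torch.Tensor:
    # L_Q(theta) = E[(r + gamma * V(s') - Q(s, a))^2]  (terminal masking omitted)
    return F.mse_loss(q_sa, reward + gamma * v_next.detach())

def awr_policy_loss(
    log_probs: torch.Tensor,
    q_sa: torch.Tensor,
    v_t: torch.Tensor,
    weight_temp: float,
    max_weight: float,
) -> torch.Tensor:
    # L_pi(phi) = -E[exp(beta * (Q(s, a) - V(s))) * log pi(a|s)], weight clipped
    exp_a = torch.exp((q_sa - v_t) * weight_temp).clamp(max=max_weight)
    return -(exp_a.detach() * log_probs).mean()

The expectile weight evaluates to tau for positive residuals and 1 - tau for negative ones, which is exactly what the (expectile - (diff < 0).float()).abs() expression computes.
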
18 changes: 10 additions & 8 deletions d3rlpy/algos/qlearning/torch/ddpg_impl.py
@@ -10,12 +10,13 @@
from ....models.torch import (
ActionOutput,
ContinuousEnsembleQFunctionForwarder,
DiscreteEnsembleQFunctionForwarder,
Policy,
)
from ....torch_utility import Modules, TorchMiniBatch, hard_sync, soft_sync
from ....types import Shape, TorchObservation
from ..base import QLearningAlgoImplBase
from .utility import ContinuousQFunctionMixin
from .utility import ContinuousQFunctionMixin, DiscreteQFunctionMixin

__all__ = [
"DDPGImpl",
@@ -158,21 +159,21 @@ def q_function_optim(self) -> Optimizer:


class DiscreteDDPGBaseImpl(
ContinuousQFunctionMixin, QLearningAlgoImplBase, metaclass=ABCMeta
DiscreteQFunctionMixin, QLearningAlgoImplBase, metaclass=ABCMeta
):
_modules: DDPGBaseModules
_gamma: float
_tau: float
_q_func_forwarder: ContinuousEnsembleQFunctionForwarder
_targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder
_q_func_forwarder: DiscreteEnsembleQFunctionForwarder
_targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder

def __init__(
self,
observation_shape: Shape,
action_size: int,
modules: DDPGBaseModules,
q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
gamma: float,
tau: float,
device: str,
@@ -216,7 +217,7 @@ def update_actor(
# Q function should be inference mode for stability
self._modules.q_funcs.eval()
self._modules.actor_optim.zero_grad()
loss = self.compute_actor_loss(batch, action)
loss = self.compute_actor_loss(batch, None)
loss.actor_loss.backward()
self._modules.actor_optim.step()
return asdict_as_float(loss)
@@ -233,7 +234,7 @@ def inner_update(

@abstractmethod
def compute_actor_loss(
self, batch: TorchMiniBatch, action: ActionOutput
self, batch: TorchMiniBatch, action: None
) -> DDPGBaseActorLoss:
pass

@@ -267,6 +268,7 @@ def q_function(self) -> nn.ModuleList:
def q_function_optim(self) -> Optimizer:
return self._modules.critic_optim


@dataclasses.dataclass(frozen=True)
class DDPGModules(DDPGBaseModules):
targ_policy: Policy
23 changes: 15 additions & 8 deletions d3rlpy/algos/qlearning/torch/iql_impl.py
Expand Up @@ -5,10 +5,10 @@

from ....models.torch import (
ActionOutput,
CategoricalPolicy,
ContinuousEnsembleQFunctionForwarder,
DiscreteEnsembleQFunctionForwarder,
NormalPolicy,
CategoricalPolicy,
ValueFunction,
build_gaussian_distribution,
)
@@ -18,12 +18,13 @@
DDPGBaseActorLoss,
DDPGBaseCriticLoss,
DDPGBaseImpl,
DiscreteDDPGBaseImpl,
DDPGBaseModules,
DiscreteDDPGBaseImpl,
)

__all__ = ["IQLImpl", "IQLModules", "DiscreteIQLImpl", "DiscreteIQLModules"]


@dataclasses.dataclass(frozen=True)
class IQLModules(DDPGBaseModules):
policy: NormalPolicy
@@ -193,16 +194,20 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
with torch.no_grad():
return self._modules.value_func(batch.next_observations)

def compute_actor_loss(self, batch: TorchMiniBatch, action) -> DDPGBaseActorLoss:
def compute_actor_loss(
self, batch: TorchMiniBatch, action: None
) -> DDPGBaseActorLoss:
assert self._modules.policy
# compute weight
with torch.no_grad():
v = self._modules.value_func(batch.observations)
min_Q = self._targ_q_func_forwarder.compute_target(batch.observations, reduction="min").gather(
1, batch.actions.long()
)
min_Q = self._targ_q_func_forwarder.compute_target(
batch.observations, reduction="min"
).gather(1, batch.actions.long())

exp_a = torch.exp((min_Q - v) * self._weight_temp).clamp(max=self._max_weight)
exp_a = torch.exp((min_Q - v) * self._weight_temp).clamp(
max=self._max_weight
)
# compute log probability
dist = self._modules.policy(batch.observations)
log_probs = dist.log_prob(batch.actions.squeeze(-1)).unsqueeze(1)
@@ -211,7 +216,9 @@ def compute_actor_loss(self, batch: TorchMiniBatch, action) -> DDPGBaseActorLoss

def compute_value_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
q_t = self._targ_q_func_forwarder.compute_expected_q(batch.observations)
one_hot = F.one_hot(batch.actions.long().view(-1), num_classes=self.action_size)
one_hot = F.one_hot(
batch.actions.long().view(-1), num_classes=self.action_size
)
q_t = (q_t * one_hot).sum(dim=1, keepdim=True)

v_t = self._modules.value_func(batch.observations)
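
A side note on the two indexing idioms visible in the DiscreteIQLImpl diff above (gather in compute_actor_loss, a one-hot reduction in compute_value_loss): both select the same per-sample Q(s, a) values for discrete actions. A small self-contained check, with purely illustrative shapes:

import torch
import torch.nn.functional as F

q_t = torch.randn(256, 6)                # (batch_size, action_size) Q-values
actions = torch.randint(0, 6, (256, 1))  # integer actions, shape (batch_size, 1)

picked_by_gather = q_t.gather(1, actions)
one_hot = F.one_hot(actions.view(-1), num_classes=6)
picked_by_one_hot = (q_t * one_hot).sum(dim=1, keepdim=True)

assert torch.allclose(picked_by_gather, picked_by_one_hot)
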
28 changes: 27 additions & 1 deletion tests/algos/qlearning/test_iql.py
@@ -2,8 +2,11 @@

import pytest

from d3rlpy.algos.qlearning.iql import IQLConfig
from d3rlpy.algos.qlearning.iql import DiscreteIQLConfig, IQLConfig
from d3rlpy.types import Shape
from d3rlpy.models.q_functions import MeanQFunctionFactory, QFunctionFactory

from ...models.torch.model_test import DummyEncoderFactory
from ...testing_utils import create_scaler_tuple
@@ -28,3 +31,26 @@ def test_iql(observation_shape: Shape, scalers: Optional[str]) -> None:
)
iql = config.create()
algo_tester(iql, observation_shape) # type: ignore


@pytest.mark.parametrize(
    "observation_shape", [(100,), (4, 84, 84), ((100,), (200,))]
)
# supply a concrete Q-function factory so pytest can fill the q_func_factory argument
@pytest.mark.parametrize("q_func_factory", [MeanQFunctionFactory()])
@pytest.mark.parametrize("scalers", [None, "min_max"])
def test_discrete_iql(
    observation_shape: Shape,
    q_func_factory: QFunctionFactory,
    scalers: Optional[str],
) -> None:
observation_scaler, _, reward_scaler = create_scaler_tuple(
scalers, observation_shape
)
config = DiscreteIQLConfig(
actor_encoder_factory=DummyEncoderFactory(),
encoder_factory=DummyEncoderFactory(),
value_encoder_factory=DummyEncoderFactory(),
q_func_factory=q_func_factory,
observation_scaler=observation_scaler,
reward_scaler=reward_scaler,
)
iql = config.create()
algo_tester(iql, observation_shape) # type: ignore
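
For context, a hypothetical end-to-end usage of the new algorithm; everything outside DiscreteIQLConfig.create() (the CartPole dataset helper and the fit() call) is assumed from d3rlpy's public API and is not part of this commit:

import d3rlpy
from d3rlpy.algos.qlearning.iql import DiscreteIQLConfig

# assumed helper: returns an offline dataset with discrete actions plus the env
dataset, env = d3rlpy.datasets.get_cartpole()

iql = DiscreteIQLConfig(batch_size=256, gamma=0.99).create(device=False)

# assumed training entry point; arguments shown are illustrative
iql.fit(dataset, n_steps=10_000, n_steps_per_epoch=1_000)
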
