add PRDC #431

Merged (4 commits, Nov 8, 2024)
1 change: 1 addition & 0 deletions .gitignore
@@ -21,3 +21,4 @@ build
dist
/.idea/
*.egg-info
/.vscode/
1 change: 1 addition & 0 deletions README.md
@@ -101,6 +101,7 @@ $ docker run -it --gpus all --name d3rlpy takuseno/d3rlpy:latest bash
| [Critic Regularized Regression (CRR)](https://arxiv.org/abs/2006.15134) | :no_entry: | :white_check_mark: |
| [Policy in Latent Action Space (PLAS)](https://arxiv.org/abs/2011.07213) | :no_entry: | :white_check_mark: |
| [TD3+BC](https://arxiv.org/abs/2106.06860) | :no_entry: | :white_check_mark: |
| [Policy Regularization with Dataset Constraint (PRDC)](https://arxiv.org/abs/2306.06569) | :no_entry: | :white_check_mark: |
| [Implicit Q-Learning (IQL)](https://arxiv.org/abs/2110.06169) | :no_entry: | :white_check_mark: |
| [Calibrated Q-Learning (Cal-QL)](https://arxiv.org/abs/2303.05479) | :no_entry: | :white_check_mark: |
| [ReBRAC](https://arxiv.org/abs/2305.09836) | :no_entry: | :white_check_mark: |
1 change: 1 addition & 0 deletions d3rlpy/algos/qlearning/__init__.py
@@ -12,6 +12,7 @@
from .iql import *
from .nfq import *
from .plas import *
from .prdc import *
from .random_policy import *
from .rebrac import *
from .sac import *
259 changes: 259 additions & 0 deletions d3rlpy/algos/qlearning/prdc.py
@@ -0,0 +1,259 @@
import dataclasses
from typing import Callable, Optional

import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from typing_extensions import Self

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace, LoggingStrategy
from ...dataset import ReplayBufferBase
from ...logging import FileAdapterFactory, LoggerAdapterFactory
from ...metrics import EvaluatorProtocol
from ...models.builders import (
create_continuous_q_function,
create_deterministic_policy,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
from ...types import Shape
from ..utility import build_scalers_with_transition_picker
from .base import QLearningAlgoBase
from .torch.ddpg_impl import DDPGModules
from .torch.prdc_impl import PRDCImpl

__all__ = ["PRDCConfig", "PRDC"]


@dataclasses.dataclass()
class PRDCConfig(LearnableConfig):
r"""Config of PRDC algorithm.

PRDC is a simple offline RL algorithm built on top of TD3.
PRDC introduces a Dataset Constraint (DC)-regularized policy objective function.

.. math::

J(\phi) = \mathbb{E}_{s \sim \mathcal{D}}
[\lambda Q(s, \pi(s)) - d^\beta_\mathcal{D}(s, \pi(s))]

where

.. math::

\lambda = \frac{\alpha}{\frac{1}{N} \sum_{(s_i, a_i)} |Q(s_i, a_i)|}

and :math:`d^\beta_\mathcal{D}(s, \pi(s))` is the DC loss, defined as

.. math::

d^\beta_\mathcal{D}(s, \pi(s)) = \min_{(\hat{s}, \hat{a}) \in \mathcal{D}}
\| (\beta s) \oplus \pi(s) - (\beta \hat{s}) \oplus \hat{a} \|

References:
* `Ran et al., Policy Regularization with Dataset Constraint for Offline
Reinforcement Learning. <https://arxiv.org/abs/2306.06569>`_

Args:
observation_scaler (d3rlpy.preprocessing.ObservationScaler):
Observation preprocessor.
action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor.
actor_learning_rate (float): Learning rate for a policy function.
critic_learning_rate (float): Learning rate for Q functions.
actor_optim_factory (d3rlpy.optimizers.OptimizerFactory):
Optimizer factory for the actor.
critic_optim_factory (d3rlpy.optimizers.OptimizerFactory):
Optimizer factory for the critic.
actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
Encoder factory for the actor.
critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
Encoder factory for the critic.
q_func_factory (d3rlpy.models.q_functions.QFunctionFactory):
Q function factory.
batch_size (int): Mini-batch size.
gamma (float): Discount factor.
tau (float): Target network synchronization coefficient.
n_critics (int): Number of Q functions for ensemble.
target_smoothing_sigma (float): Standard deviation for target noise.
target_smoothing_clip (float): Clipping range for target noise.
alpha (float): :math:`\alpha` value.
beta (float): :math:`\beta` value.
update_actor_interval (int): Interval to update policy function
described as `delayed policy update` in the paper.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 3e-4
critic_learning_rate: float = 3e-4
actor_optim_factory: OptimizerFactory = make_optimizer_field()
critic_optim_factory: OptimizerFactory = make_optimizer_field()
actor_encoder_factory: EncoderFactory = make_encoder_field()
critic_encoder_factory: EncoderFactory = make_encoder_field()
q_func_factory: QFunctionFactory = make_q_func_field()
batch_size: int = 256
gamma: float = 0.99
tau: float = 0.005
n_critics: int = 2
target_smoothing_sigma: float = 0.2
target_smoothing_clip: float = 0.5
alpha: float = 2.5
beta: float = 2.0
update_actor_interval: int = 2

def create(self, device: DeviceArg = False, enable_ddp: bool = False) -> "PRDC":
return PRDC(self, device, enable_ddp)

@staticmethod
def get_type() -> str:
return "prdc"


class PRDC(QLearningAlgoBase[PRDCImpl, PRDCConfig]):
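# Nearest-neighbor index over (beta * observation, action) keys;
# fitted in fit() below and queried by PRDCImpl.compute_actor_loss().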
_nbsr = NearestNeighbors(n_neighbors=1, algorithm="auto", n_jobs=-1)

def inner_create_impl(self, observation_shape: Shape, action_size: int) -> None:
policy = create_deterministic_policy(
observation_shape,
action_size,
self._config.actor_encoder_factory,
device=self._device,
enable_ddp=self._enable_ddp,
)
targ_policy = create_deterministic_policy(
observation_shape,
action_size,
self._config.actor_encoder_factory,
device=self._device,
enable_ddp=self._enable_ddp,
)
q_funcs, q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
enable_ddp=self._enable_ddp,
)
targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
enable_ddp=self._enable_ddp,
)

actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(),
lr=self._config.actor_learning_rate,
compiled=self.compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
compiled=self.compiled,
)

modules = DDPGModules(
policy=policy,
targ_policy=targ_policy,
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
actor_optim=actor_optim,
critic_optim=critic_optim,
)

self._impl = PRDCImpl(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
tau=self._config.tau,
target_smoothing_sigma=self._config.target_smoothing_sigma,
target_smoothing_clip=self._config.target_smoothing_clip,
alpha=self._config.alpha,
beta=self._config.beta,
update_actor_interval=self._config.update_actor_interval,
compiled=self.compiled,
nbsr=self._nbsr,
device=self._device,
)

def fit(
self,
dataset: ReplayBufferBase,
n_steps: int,
n_steps_per_epoch: int = 10000,
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logging_steps: int = 500,
logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
show_progress: bool = True,
save_interval: int = 1,
evaluators: Optional[dict[str, EvaluatorProtocol]] = None,
callback: Optional[Callable[[Self, int, int], None]] = None,
epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
) -> list[tuple[int, dict[str, float]]]:
observations = []
actions = []
for episode in dataset.buffer.episodes:
for i in range(episode.transition_count):
transition = dataset.transition_picker(episode, i)
observations.append(np.reshape(transition.observation, (1, -1)))
actions.append(np.reshape(transition.action, (1, -1)))
observations = np.concatenate(observations, axis=0)
actions = np.concatenate(actions, axis=0)

build_scalers_with_transition_picker(self, dataset)
if self.observation_scaler and self.observation_scaler.built:
observations = (
self.observation_scaler.transform(
torch.tensor(observations, device=self._device)
)
.cpu()
.numpy()
)

if self.action_scaler and self.action_scaler.built:
actions = (
self.action_scaler.transform(torch.tensor(actions, device=self._device))
.cpu()
.numpy()
)

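# Build the DC nearest-neighbor index over (beta * s, a) pairs from the (scaled) dataset.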
self._nbsr.fit(
np.concatenate(
[np.multiply(observations, self._config.beta), actions],
axis=1,
)
)

return super().fit(
dataset=dataset,
n_steps=n_steps,
n_steps_per_epoch=n_steps_per_epoch,
logging_steps=logging_steps,
logging_strategy=logging_strategy,
experiment_name=experiment_name,
with_timestamp=with_timestamp,
logger_adapter=logger_adapter,
show_progress=show_progress,
save_interval=save_interval,
evaluators=evaluators,
callback=callback,
epoch_callback=epoch_callback,
)

def get_action_type(self) -> ActionSpace:
return ActionSpace.CONTINUOUS


register_learnable(PRDCConfig)
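
For reviewers who want to try the new algorithm end to end, here is a minimal usage sketch through d3rlpy's existing Q-learning API. It is not part of this diff; `get_pendulum` and `EnvironmentEvaluator` are assumed from the current d3rlpy interface, and the hyperparameters simply repeat the defaults introduced above.

```python
import d3rlpy

# Small continuous-control dataset bundled with d3rlpy (assumed helper).
dataset, env = d3rlpy.datasets.get_pendulum()

# Instantiate PRDC with the defaults introduced in this PR.
prdc = d3rlpy.algos.PRDCConfig(alpha=2.5, beta=2.0).create(device="cpu")

# PRDC.fit() first builds the nearest-neighbor index over (beta * s, a)
# pairs, then delegates to the standard d3rlpy training loop.
prdc.fit(
    dataset,
    n_steps=10000,
    n_steps_per_epoch=1000,
    evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
)
```

Since the config is registered via `register_learnable`, saved PRDC models should also load back through `d3rlpy.load_learnable` like the other algorithms.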
1 change: 1 addition & 0 deletions d3rlpy/algos/qlearning/torch/__init__.py
@@ -9,6 +9,7 @@
from .dqn_impl import *
from .iql_impl import *
from .plas_impl import *
from .prdc_impl import *
from .rebrac_impl import *
from .sac_impl import *
from .td3_impl import *
84 changes: 84 additions & 0 deletions d3rlpy/algos/qlearning/torch/prdc_impl.py
@@ -0,0 +1,84 @@
# pylint: disable=too-many-ancestors
import dataclasses

import torch
from sklearn.neighbors import NearestNeighbors

from ....models.torch import ActionOutput, ContinuousEnsembleQFunctionForwarder
from ....torch_utility import TorchMiniBatch
from ....types import Shape
from .ddpg_impl import DDPGBaseActorLoss, DDPGModules
from .td3_impl import TD3Impl

__all__ = ["PRDCImpl"]


@dataclasses.dataclass(frozen=True)
class PRDCActorLoss(DDPGBaseActorLoss):
dc_loss: torch.Tensor


class PRDCImpl(TD3Impl):
_alpha: float
_beta: float
_nbsr: NearestNeighbors

def __init__(
self,
observation_shape: Shape,
action_size: int,
modules: DDPGModules,
q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
gamma: float,
tau: float,
target_smoothing_sigma: float,
target_smoothing_clip: float,
alpha: float,
beta: float,
update_actor_interval: int,
compiled: bool,
nbsr: NearestNeighbors,
device: str,
):
super().__init__(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=gamma,
tau=tau,
target_smoothing_sigma=target_smoothing_sigma,
target_smoothing_clip=target_smoothing_clip,
update_actor_interval=update_actor_interval,
compiled=compiled,
device=device,
)
self._alpha = alpha
self._beta = beta
self._nbsr = nbsr

def compute_actor_loss(
self, batch: TorchMiniBatch, action: ActionOutput
) -> PRDCActorLoss:
q_t = self._q_func_forwarder.compute_expected_q(
batch.observations, action.squashed_mu, "none"
)[0]
lam = self._alpha / (q_t.abs().mean()).detach()
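# Nearest-neighbor query key: beta-scaled observations concatenated with the current policy action.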
key = (
torch.cat(
[torch.mul(batch.observations, self._beta), action.squashed_mu], dim=-1
)
.detach()
.cpu()
.numpy()
)
idx = self._nbsr.kneighbors(key, n_neighbors=1, return_distance=False)
nearest_neighbour = torch.tensor(
self._nbsr._fit_X[idx][:, :, -self.action_size :],
device=self.device,
dtype=action.squashed_mu.dtype,
).squeeze(dim=1)
dc_loss = torch.nn.functional.mse_loss(action.squashed_mu, nearest_neighbour)
return PRDCActorLoss(actor_loss=lam * -q_t.mean() + dc_loss, dc_loss=dc_loss)
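
To make the dataset-constraint term easier to review in isolation, the following self-contained sketch reproduces the nearest-neighbor lookup that `compute_actor_loss` performs. The shapes, `beta` value, and random data are illustrative assumptions; the `(beta * s) ⊕ a` key layout and the `_fit_X` slicing mirror the implementation above.

```python
import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors

beta, obs_dim, act_dim, n = 2.0, 4, 2, 1000

# Offline dataset: the index is fitted on concatenated (beta * s, a) pairs,
# mirroring what PRDC.fit() does before training starts.
dataset_obs = np.random.randn(n, obs_dim).astype(np.float32)
dataset_act = np.random.uniform(-1, 1, (n, act_dim)).astype(np.float32)
nbsr = NearestNeighbors(n_neighbors=1).fit(
    np.concatenate([beta * dataset_obs, dataset_act], axis=1)
)

# A mini-batch of observations and the corresponding policy actions.
batch_obs = torch.randn(8, obs_dim)
policy_act = torch.tanh(torch.randn(8, act_dim))

# Query the nearest dataset pair for each (beta * s, pi(s)) key.
key = torch.cat([beta * batch_obs, policy_act], dim=-1).numpy()
idx = nbsr.kneighbors(key, n_neighbors=1, return_distance=False)

# Keep only the action part of the nearest neighbor and regress onto it
# (_fit_X is sklearn-private; the implementation above relies on it too).
nearest_act = torch.tensor(
    nbsr._fit_X[idx][:, :, -act_dim:], dtype=policy_act.dtype
).squeeze(dim=1)
dc_loss = torch.nn.functional.mse_loss(policy_act, nearest_act)
print(float(dc_loss))
```

The lookup runs on CPU through scikit-learn, which is why `compute_actor_loss` detaches the batch and moves it to the host before querying the index.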
3 changes: 3 additions & 0 deletions mypy.ini
@@ -71,3 +71,6 @@ follow_imports_for_stubs = True
ignore_missing_imports = True
follow_imports = skip
follow_imports_for_stubs = True

[mypy-sklearn.*]
ignore_missing_imports = True