diff --git a/.gitignore b/.gitignore
index 348562a3..049c3f06 100644
--- a/.gitignore
+++ b/.gitignore
@@ -21,3 +21,4 @@ build
 dist
 /.idea/
 *.egg-info
+/.vscode/
diff --git a/README.md b/README.md
index 951e405a..ce8a1a8e 100644
--- a/README.md
+++ b/README.md
@@ -101,6 +101,7 @@ $ docker run -it --gpus all --name d3rlpy takuseno/d3rlpy:latest bash
 | [Critic Reguralized Regression (CRR)](https://arxiv.org/abs/2006.15134) | :no_entry: | :white_check_mark: |
 | [Policy in Latent Action Space (PLAS)](https://arxiv.org/abs/2011.07213) | :no_entry: | :white_check_mark: |
 | [TD3+BC](https://arxiv.org/abs/2106.06860) | :no_entry: | :white_check_mark: |
+| [Policy Regularization with Dataset Constraint (PRDC)](https://arxiv.org/abs/2306.06569) | :no_entry: | :white_check_mark: |
 | [Implicit Q-Learning (IQL)](https://arxiv.org/abs/2110.06169) | :no_entry: | :white_check_mark: |
 | [Calibrated Q-Learning (Cal-QL)](https://arxiv.org/abs/2303.05479) | :no_entry: | :white_check_mark: |
 | [ReBRAC](https://arxiv.org/abs/2305.09836) | :no_entry: | :white_check_mark: |
diff --git a/d3rlpy/algos/qlearning/__init__.py b/d3rlpy/algos/qlearning/__init__.py
index 82c72fe1..081ec4e9 100644
--- a/d3rlpy/algos/qlearning/__init__.py
+++ b/d3rlpy/algos/qlearning/__init__.py
@@ -12,6 +12,7 @@
 from .iql import *
 from .nfq import *
 from .plas import *
+from .prdc import *
 from .random_policy import *
 from .rebrac import *
 from .sac import *
diff --git a/d3rlpy/algos/qlearning/prdc.py b/d3rlpy/algos/qlearning/prdc.py
new file mode 100644
index 00000000..2ac01a14
--- /dev/null
+++ b/d3rlpy/algos/qlearning/prdc.py
@@ -0,0 +1,259 @@
+import dataclasses
+from typing import Callable, Optional
+
+import numpy as np
+import torch
+from sklearn.neighbors import NearestNeighbors
+from typing_extensions import Self
+
+from ...base import DeviceArg, LearnableConfig, register_learnable
+from ...constants import ActionSpace, LoggingStrategy
+from ...dataset import ReplayBufferBase
+from ...logging import FileAdapterFactory, LoggerAdapterFactory
+from ...metrics import EvaluatorProtocol
+from ...models.builders import (
+    create_continuous_q_function,
+    create_deterministic_policy,
+)
+from ...models.encoders import EncoderFactory, make_encoder_field
+from ...models.q_functions import QFunctionFactory, make_q_func_field
+from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
+from ...types import Shape
+from ..utility import build_scalers_with_transition_picker
+from .base import QLearningAlgoBase
+from .torch.ddpg_impl import DDPGModules
+from .torch.prdc_impl import PRDCImpl
+
+__all__ = ["PRDCConfig", "PRDC"]
+
+
+@dataclasses.dataclass()
+class PRDCConfig(LearnableConfig):
+    r"""Config of PRDC algorithm.
+
+    PRDC is a simple offline RL algorithm built on top of TD3.
+    PRDC introduces a Dataset Constraint (DC)-regularized policy objective function.
+
+    .. math::
+
+        J(\phi) = \mathbb{E}_{s \sim \mathcal{D}}
+            [\lambda Q(s, \pi(s)) - d^\beta_\mathcal{D}(s, \pi(s))]
+
+    where
+
+    .. math::
+
+        \lambda = \frac{\alpha}{\frac{1}{N} \sum_{(s_i, a_i)} |Q(s_i, a_i)|}
+
+    and :math:`d^\beta_\mathcal{D}(s, \pi(s))` is the DC loss, defined as
+
+    .. math::
+
+        d^\beta_\mathcal{D}(s, \pi(s)) = \min_{(\hat{s}, \hat{a}) \in \mathcal{D}}
+            [\| (\beta s) \oplus \pi(s) - (\beta \hat{s}) \oplus \hat{a} \|]
+
+    References:
+        * `Ran et al., Policy Regularization with Dataset Constraint for Offline
+          Reinforcement Learning. <https://arxiv.org/abs/2306.06569>`_
+
+    Args:
+        observation_scaler (d3rlpy.preprocessing.ObservationScaler):
+            Observation preprocessor.
+ action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor. + reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor. + actor_learning_rate (float): Learning rate for a policy function. + critic_learning_rate (float): Learning rate for Q functions. + actor_optim_factory (d3rlpy.optimizers.OptimizerFactory): + Optimizer factory for the actor. + critic_optim_factory (d3rlpy.optimizers.OptimizerFactory): + Optimizer factory for the critic. + actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory): + Encoder factory for the actor. + critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory): + Encoder factory for the critic. + q_func_factory (d3rlpy.models.q_functions.QFunctionFactory): + Q function factory. + batch_size (int): Mini-batch size. + gamma (float): Discount factor. + tau (float): Target network synchronization coefficiency. + n_critics (int): Number of Q functions for ensemble. + target_smoothing_sigma (float): Standard deviation for target noise. + target_smoothing_clip (float): Clipping range for target noise. + alpha (float): :math:`\alpha` value. + beta (float): :math:`\beta` value. + update_actor_interval (int): Interval to update policy function + described as `delayed policy update` in the paper. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. + """ + + actor_learning_rate: float = 3e-4 + critic_learning_rate: float = 3e-4 + actor_optim_factory: OptimizerFactory = make_optimizer_field() + critic_optim_factory: OptimizerFactory = make_optimizer_field() + actor_encoder_factory: EncoderFactory = make_encoder_field() + critic_encoder_factory: EncoderFactory = make_encoder_field() + q_func_factory: QFunctionFactory = make_q_func_field() + batch_size: int = 256 + gamma: float = 0.99 + tau: float = 0.005 + n_critics: int = 2 + target_smoothing_sigma: float = 0.2 + target_smoothing_clip: float = 0.5 + alpha: float = 2.5 + beta: float = 2.0 + update_actor_interval: int = 2 + + def create(self, device: DeviceArg = False, enable_ddp: bool = False) -> "PRDC": + return PRDC(self, device, enable_ddp) + + @staticmethod + def get_type() -> str: + return "prdc" + + +class PRDC(QLearningAlgoBase[PRDCImpl, PRDCConfig]): + _nbsr = NearestNeighbors(n_neighbors=1, algorithm="auto", n_jobs=-1) + + def inner_create_impl(self, observation_shape: Shape, action_size: int) -> None: + policy = create_deterministic_policy( + observation_shape, + action_size, + self._config.actor_encoder_factory, + device=self._device, + enable_ddp=self._enable_ddp, + ) + targ_policy = create_deterministic_policy( + observation_shape, + action_size, + self._config.actor_encoder_factory, + device=self._device, + enable_ddp=self._enable_ddp, + ) + q_funcs, q_func_forwarder = create_continuous_q_function( + observation_shape, + action_size, + self._config.critic_encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + enable_ddp=self._enable_ddp, + ) + targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( + observation_shape, + action_size, + self._config.critic_encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + enable_ddp=self._enable_ddp, + ) + + actor_optim = self._config.actor_optim_factory.create( + policy.named_modules(), + lr=self._config.actor_learning_rate, + compiled=self.compiled, + ) + critic_optim = self._config.critic_optim_factory.create( + q_funcs.named_modules(), + lr=self._config.critic_learning_rate, + 
compiled=self.compiled, + ) + + modules = DDPGModules( + policy=policy, + targ_policy=targ_policy, + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + actor_optim=actor_optim, + critic_optim=critic_optim, + ) + + self._impl = PRDCImpl( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + q_func_forwarder=q_func_forwarder, + targ_q_func_forwarder=targ_q_func_forwarder, + gamma=self._config.gamma, + tau=self._config.tau, + target_smoothing_sigma=self._config.target_smoothing_sigma, + target_smoothing_clip=self._config.target_smoothing_clip, + alpha=self._config.alpha, + beta=self._config.beta, + update_actor_interval=self._config.update_actor_interval, + compiled=self.compiled, + nbsr=self._nbsr, + device=self._device, + ) + + def fit( + self, + dataset: ReplayBufferBase, + n_steps: int, + n_steps_per_epoch: int = 10000, + experiment_name: Optional[str] = None, + with_timestamp: bool = True, + logging_steps: int = 500, + logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH, + logger_adapter: LoggerAdapterFactory = FileAdapterFactory(), + show_progress: bool = True, + save_interval: int = 1, + evaluators: Optional[dict[str, EvaluatorProtocol]] = None, + callback: Optional[Callable[[Self, int, int], None]] = None, + epoch_callback: Optional[Callable[[Self, int, int], None]] = None, + ) -> list[tuple[int, dict[str, float]]]: + observations = [] + actions = [] + for episode in dataset.buffer.episodes: + for i in range(episode.transition_count): + transition = dataset.transition_picker(episode, i) + observations.append(np.reshape(transition.observation, (1, -1))) + actions.append(np.reshape(transition.action, (1, -1))) + observations = np.concatenate(observations, axis=0) + actions = np.concatenate(actions, axis=0) + + build_scalers_with_transition_picker(self, dataset) + if self.observation_scaler and self.observation_scaler.built: + observations = ( + self.observation_scaler.transform( + torch.tensor(observations, device=self._device) + ) + .cpu() + .numpy() + ) + + if self.action_scaler and self.action_scaler.built: + actions = ( + self.action_scaler.transform(torch.tensor(actions, device=self._device)) + .cpu() + .numpy() + ) + + self._nbsr.fit( + np.concatenate( + [np.multiply(observations, self._config.beta), actions], + axis=1, + ) + ) + + return super().fit( + dataset=dataset, + n_steps=n_steps, + n_steps_per_epoch=n_steps_per_epoch, + logging_steps=logging_steps, + logging_strategy=logging_strategy, + experiment_name=experiment_name, + with_timestamp=with_timestamp, + logger_adapter=logger_adapter, + show_progress=show_progress, + save_interval=save_interval, + evaluators=evaluators, + callback=callback, + epoch_callback=epoch_callback, + ) + + def get_action_type(self) -> ActionSpace: + return ActionSpace.CONTINUOUS + + +register_learnable(PRDCConfig) diff --git a/d3rlpy/algos/qlearning/torch/__init__.py b/d3rlpy/algos/qlearning/torch/__init__.py index bf545348..66253ea0 100644 --- a/d3rlpy/algos/qlearning/torch/__init__.py +++ b/d3rlpy/algos/qlearning/torch/__init__.py @@ -9,6 +9,7 @@ from .dqn_impl import * from .iql_impl import * from .plas_impl import * +from .prdc_impl import * from .rebrac_impl import * from .sac_impl import * from .td3_impl import * diff --git a/d3rlpy/algos/qlearning/torch/prdc_impl.py b/d3rlpy/algos/qlearning/torch/prdc_impl.py new file mode 100644 index 00000000..0ca2e323 --- /dev/null +++ b/d3rlpy/algos/qlearning/torch/prdc_impl.py @@ -0,0 +1,84 @@ +# pylint: disable=too-many-ancestors +import dataclasses + +import torch 
+from sklearn.neighbors import NearestNeighbors
+
+from ....models.torch import ActionOutput, ContinuousEnsembleQFunctionForwarder
+from ....torch_utility import TorchMiniBatch
+from ....types import Shape
+from .ddpg_impl import DDPGBaseActorLoss, DDPGModules
+from .td3_impl import TD3Impl
+
+__all__ = ["PRDCImpl"]
+
+
+@dataclasses.dataclass(frozen=True)
+class PRDCActorLoss(DDPGBaseActorLoss):
+    dc_loss: torch.Tensor
+
+
+class PRDCImpl(TD3Impl):
+    _alpha: float
+    _beta: float
+    _nbsr: NearestNeighbors
+
+    def __init__(
+        self,
+        observation_shape: Shape,
+        action_size: int,
+        modules: DDPGModules,
+        q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
+        targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
+        gamma: float,
+        tau: float,
+        target_smoothing_sigma: float,
+        target_smoothing_clip: float,
+        alpha: float,
+        beta: float,
+        update_actor_interval: int,
+        compiled: bool,
+        nbsr: NearestNeighbors,
+        device: str,
+    ):
+        super().__init__(
+            observation_shape=observation_shape,
+            action_size=action_size,
+            modules=modules,
+            q_func_forwarder=q_func_forwarder,
+            targ_q_func_forwarder=targ_q_func_forwarder,
+            gamma=gamma,
+            tau=tau,
+            target_smoothing_sigma=target_smoothing_sigma,
+            target_smoothing_clip=target_smoothing_clip,
+            update_actor_interval=update_actor_interval,
+            compiled=compiled,
+            device=device,
+        )
+        self._alpha = alpha
+        self._beta = beta
+        self._nbsr = nbsr
+
+    def compute_actor_loss(
+        self, batch: TorchMiniBatch, action: ActionOutput
+    ) -> PRDCActorLoss:
+        q_t = self._q_func_forwarder.compute_expected_q(
+            batch.observations, action.squashed_mu, "none"
+        )[0]
+        lam = self._alpha / (q_t.abs().mean()).detach()
+        key = (
+            torch.cat(
+                [torch.mul(batch.observations, self._beta), action.squashed_mu], dim=-1
+            )
+            .detach()
+            .cpu()
+            .numpy()
+        )
+        idx = self._nbsr.kneighbors(key, n_neighbors=1, return_distance=False)
+        nearest_neighbor = torch.tensor(
+            self._nbsr._fit_X[idx][:, :, -self.action_size :],
+            device=self.device,
+            dtype=action.squashed_mu.dtype,
+        ).squeeze(dim=1)
+        dc_loss = torch.nn.functional.mse_loss(action.squashed_mu, nearest_neighbor)
+        return PRDCActorLoss(actor_loss=lam * -q_t.mean() + dc_loss, dc_loss=dc_loss)
diff --git a/mypy.ini b/mypy.ini
index 4e09045b..1d011dbc 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -71,3 +71,6 @@ follow_imports_for_stubs = True
 ignore_missing_imports = True
 follow_imports = skip
 follow_imports_for_stubs = True
+
+[mypy-sklearn.*]
+ignore_missing_imports = True
diff --git a/reproductions/offline/prdc.py b/reproductions/offline/prdc.py
new file mode 100644
index 00000000..c941569a
--- /dev/null
+++ b/reproductions/offline/prdc.py
@@ -0,0 +1,43 @@
+import argparse
+
+import d3rlpy
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset", type=str, default="hopper-medium-v2")
+    parser.add_argument("--seed", type=int, default=1)
+    parser.add_argument("--gpu", type=int, default=0)
+    parser.add_argument("--compile", action="store_true")
+    args = parser.parse_args()
+
+    dataset, env = d3rlpy.datasets.get_dataset(args.dataset)
+
+    # fix seed
+    d3rlpy.seed(args.seed)
+    d3rlpy.envs.seed_env(env, args.seed)
+
+    prdc = d3rlpy.algos.PRDCConfig(
+        actor_learning_rate=3e-4,
+        critic_learning_rate=3e-4,
+        batch_size=256,
+        target_smoothing_sigma=0.2,
+        target_smoothing_clip=0.5,
+        alpha=2.5,
+        update_actor_interval=2,
+        observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(),
+        compile_graph=args.compile,
+    ).create(device=args.gpu)
+
+    prdc.fit(
+        dataset,
n_steps=500000, + n_steps_per_epoch=1000, + save_interval=10, + evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}, + experiment_name=f"PRDC_{args.dataset}_{args.seed}", + ) + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 276f1d3c..0505e347 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,7 @@ "colorama", "dataclasses-json", "gymnasium>=1.0.0", + "scikit-learn", ], packages=find_packages(exclude=["tests*"]), python_requires=">=3.9.0", diff --git a/tests/algos/qlearning/test_prdc.py b/tests/algos/qlearning/test_prdc.py new file mode 100644 index 00000000..2ce84bdf --- /dev/null +++ b/tests/algos/qlearning/test_prdc.py @@ -0,0 +1,40 @@ +from typing import Optional + +import pytest + +from d3rlpy.algos.qlearning.prdc import PRDCConfig +from d3rlpy.models import ( + MeanQFunctionFactory, + QFunctionFactory, + QRQFunctionFactory, +) +from d3rlpy.types import Shape + +from ...models.torch.model_test import DummyEncoderFactory +from ...testing_utils import create_scaler_tuple +from .algo_test import algo_tester + + +@pytest.mark.parametrize("observation_shape", [(100,), (17,)]) +@pytest.mark.parametrize( + "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] +) +@pytest.mark.parametrize("scalers", [None, "min_max"]) +def test_prdc( + observation_shape: Shape, + q_func_factory: QFunctionFactory, + scalers: Optional[str], +) -> None: + observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( + scalers, observation_shape + ) + config = PRDCConfig( + actor_encoder_factory=DummyEncoderFactory(), + critic_encoder_factory=DummyEncoderFactory(), + q_func_factory=q_func_factory, + observation_scaler=observation_scaler, + action_scaler=action_scaler, + reward_scaler=reward_scaler, + ) + prdc = config.create() + algo_tester(prdc, observation_shape) # type: ignore
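
Reviewer note (not part of the diff): the sketch below illustrates, on toy data, the Dataset Constraint regularizer that `PRDC.fit` and `PRDCImpl.compute_actor_loss` implement above: the offline dataset is indexed by keys `(beta * s) ⊕ a`, each policy action is looked up with the key `(beta * s) ⊕ pi(s)`, and the squared distance to the action part of the nearest dataset pair is penalized. The array shapes and random arrays are illustrative assumptions only, and the full actor loss in the diff also includes the `-lambda * Q(s, pi(s))` term.

# Minimal sketch of the DC loss, assuming toy random data in place of a real dataset.
import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors

beta = 2.0  # weight on the observation part of the k-NN key (PRDCConfig.beta)

# Toy offline dataset: 1000 transitions with 17-dim observations and 6-dim actions.
dataset_obs = np.random.randn(1000, 17).astype(np.float32)
dataset_act = np.random.uniform(-1, 1, size=(1000, 6)).astype(np.float32)

# Index keys are (beta * s) concatenated with a, as built in PRDC.fit().
index = NearestNeighbors(n_neighbors=1, algorithm="auto", n_jobs=-1)
index.fit(np.concatenate([beta * dataset_obs, dataset_act], axis=1))

# A batch of observations and the (toy) squashed policy actions for them.
batch_obs = torch.randn(256, 17)
policy_act = torch.tanh(torch.randn(256, 6))

# Query keys are (beta * s) concatenated with pi(s), as in compute_actor_loss().
key = torch.cat([beta * batch_obs, policy_act], dim=-1).detach().numpy()
idx = index.kneighbors(key, n_neighbors=1, return_distance=False)

# Keep only the action part of each nearest dataset pair and penalize the
# squared distance between it and the policy action.
nearest_act = torch.tensor(dataset_act[idx[:, 0]])
dc_loss = torch.nn.functional.mse_loss(policy_act, nearest_act)
print(float(dc_loss))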