diff --git a/d3rlpy/algos/qlearning/prdc.py b/d3rlpy/algos/qlearning/prdc.py index 89f8234e..2ac01a14 100644 --- a/d3rlpy/algos/qlearning/prdc.py +++ b/d3rlpy/algos/qlearning/prdc.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Callable, Generator, Optional +from typing import Callable, Optional import numpy as np import torch @@ -11,7 +11,10 @@ from ...dataset import ReplayBufferBase from ...logging import FileAdapterFactory, LoggerAdapterFactory from ...metrics import EvaluatorProtocol -from ...models.builders import create_continuous_q_function, create_deterministic_policy +from ...models.builders import ( + create_continuous_q_function, + create_deterministic_policy, +) from ...models.encoders import EncoderFactory, make_encoder_field from ...models.q_functions import QFunctionFactory, make_q_func_field from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field @@ -188,43 +191,47 @@ def fit( dataset: ReplayBufferBase, n_steps: int, n_steps_per_epoch: int = 10000, - logging_steps: int = 500, - logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH, experiment_name: Optional[str] = None, with_timestamp: bool = True, + logging_steps: int = 500, + logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH, logger_adapter: LoggerAdapterFactory = FileAdapterFactory(), show_progress: bool = True, save_interval: int = 1, evaluators: Optional[dict[str, EvaluatorProtocol]] = None, callback: Optional[Callable[[Self, int, int], None]] = None, epoch_callback: Optional[Callable[[Self, int, int], None]] = None, - ) -> Generator[tuple[int, dict[str, float]], None, None]: + ) -> list[tuple[int, dict[str, float]]]: observations = [] actions = [] for episode in dataset.buffer.episodes: for i in range(episode.transition_count): transition = dataset.transition_picker(episode, i) - observations.append(transition.observation.reshape(1, -1)) - actions.append(transition.action.reshape(1, -1)) + observations.append(np.reshape(transition.observation, (1, -1))) + actions.append(np.reshape(transition.action, (1, -1))) observations = np.concatenate(observations, axis=0) actions = np.concatenate(actions, axis=0) build_scalers_with_transition_picker(self, dataset) if self.observation_scaler and self.observation_scaler.built: - observations = self.observation_scaler.transform( - torch.tensor(observations, device=self._device) + observations = ( + self.observation_scaler.transform( + torch.tensor(observations, device=self._device) + ) + .cpu() + .numpy() ) - observations = observations.cpu().numpy() if self.action_scaler and self.action_scaler.built: - actions = self.action_scaler.transform( - torch.tensor(actions, device=self._device) + actions = ( + self.action_scaler.transform(torch.tensor(actions, device=self._device)) + .cpu() + .numpy() ) - actions = actions.cpu().numpy() self._nbsr.fit( np.concatenate( - [self._config.beta * observations, actions], + [np.multiply(observations, self._config.beta), actions], axis=1, ) ) diff --git a/d3rlpy/algos/qlearning/torch/prdc_impl.py b/d3rlpy/algos/qlearning/torch/prdc_impl.py index 4add2280..0ca2e323 100644 --- a/d3rlpy/algos/qlearning/torch/prdc_impl.py +++ b/d3rlpy/algos/qlearning/torch/prdc_impl.py @@ -8,7 +8,7 @@ from ....torch_utility import TorchMiniBatch from ....types import Shape from .ddpg_impl import DDPGBaseActorLoss, DDPGModules -from .td3_plus_bc_impl import TD3PlusBCImpl +from .td3_impl import TD3Impl __all__ = ["PRDCImpl"] @@ -18,8 +18,9 @@ class PRDCActorLoss(DDPGBaseActorLoss): dc_loss: torch.Tensor -class PRDCImpl(TD3PlusBCImpl): - _beta: float = 2.0 +class PRDCImpl(TD3Impl): + _alpha: float + _beta: float _nbsr: NearestNeighbors def __init__( @@ -50,11 +51,11 @@ def __init__( tau=tau, target_smoothing_sigma=target_smoothing_sigma, target_smoothing_clip=target_smoothing_clip, - alpha=alpha, update_actor_interval=update_actor_interval, compiled=compiled, device=device, ) + self._alpha = alpha self._beta = beta self._nbsr = nbsr @@ -66,7 +67,9 @@ def compute_actor_loss( )[0] lam = self._alpha / (q_t.abs().mean()).detach() key = ( - torch.cat([self._beta * batch.observations, action.squashed_mu], dim=-1) + torch.cat( + [torch.mul(batch.observations, self._beta), action.squashed_mu], dim=-1 + ) .detach() .cpu() .numpy() diff --git a/mypy.ini b/mypy.ini index 4e09045b..1d011dbc 100644 --- a/mypy.ini +++ b/mypy.ini @@ -71,3 +71,6 @@ follow_imports_for_stubs = True ignore_missing_imports = True follow_imports = skip follow_imports_for_stubs = True + +[mypy-sklearn.*] +ignore_missing_imports = True diff --git a/tests/algos/qlearning/test_prdc.py b/tests/algos/qlearning/test_prdc.py index 0c8d87f5..2ce84bdf 100644 --- a/tests/algos/qlearning/test_prdc.py +++ b/tests/algos/qlearning/test_prdc.py @@ -3,7 +3,11 @@ import pytest from d3rlpy.algos.qlearning.prdc import PRDCConfig -from d3rlpy.models import MeanQFunctionFactory, QFunctionFactory, QRQFunctionFactory +from d3rlpy.models import ( + MeanQFunctionFactory, + QFunctionFactory, + QRQFunctionFactory, +) from d3rlpy.types import Shape from ...models.torch.model_test import DummyEncoderFactory