Commit

Fix errors from lint
liyc-ai committed Nov 8, 2024
1 parent eb7ee38 commit 1f8144b
Showing 4 changed files with 37 additions and 20 deletions.
35 changes: 21 additions & 14 deletions d3rlpy/algos/qlearning/prdc.py
@@ -1,5 +1,5 @@
 import dataclasses
-from typing import Callable, Generator, Optional
+from typing import Callable, Optional
 
 import numpy as np
 import torch
@@ -11,7 +11,10 @@
 from ...dataset import ReplayBufferBase
 from ...logging import FileAdapterFactory, LoggerAdapterFactory
 from ...metrics import EvaluatorProtocol
-from ...models.builders import create_continuous_q_function, create_deterministic_policy
+from ...models.builders import (
+    create_continuous_q_function,
+    create_deterministic_policy,
+)
 from ...models.encoders import EncoderFactory, make_encoder_field
 from ...models.q_functions import QFunctionFactory, make_q_func_field
 from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
@@ -188,43 +191,47 @@ def fit(
         dataset: ReplayBufferBase,
         n_steps: int,
         n_steps_per_epoch: int = 10000,
-        logging_steps: int = 500,
-        logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
         experiment_name: Optional[str] = None,
         with_timestamp: bool = True,
+        logging_steps: int = 500,
+        logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
         logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
         show_progress: bool = True,
         save_interval: int = 1,
         evaluators: Optional[dict[str, EvaluatorProtocol]] = None,
         callback: Optional[Callable[[Self, int, int], None]] = None,
         epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
-    ) -> Generator[tuple[int, dict[str, float]], None, None]:
+    ) -> list[tuple[int, dict[str, float]]]:
         observations = []
         actions = []
         for episode in dataset.buffer.episodes:
             for i in range(episode.transition_count):
                 transition = dataset.transition_picker(episode, i)
-                observations.append(transition.observation.reshape(1, -1))
-                actions.append(transition.action.reshape(1, -1))
+                observations.append(np.reshape(transition.observation, (1, -1)))
+                actions.append(np.reshape(transition.action, (1, -1)))
         observations = np.concatenate(observations, axis=0)
         actions = np.concatenate(actions, axis=0)
 
         build_scalers_with_transition_picker(self, dataset)
         if self.observation_scaler and self.observation_scaler.built:
-            observations = self.observation_scaler.transform(
-                torch.tensor(observations, device=self._device)
+            observations = (
+                self.observation_scaler.transform(
+                    torch.tensor(observations, device=self._device)
+                )
+                .cpu()
+                .numpy()
             )
-            observations = observations.cpu().numpy()
 
         if self.action_scaler and self.action_scaler.built:
-            actions = self.action_scaler.transform(
-                torch.tensor(actions, device=self._device)
+            actions = (
+                self.action_scaler.transform(torch.tensor(actions, device=self._device))
+                .cpu()
+                .numpy()
             )
-            actions = actions.cpu().numpy()
 
         self._nbsr.fit(
             np.concatenate(
-                [self._config.beta * observations, actions],
+                [np.multiply(observations, self._config.beta), actions],
                 axis=1,
             )
         )
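For context on the hunk above: the `fit()` override flattens every transition in the replay buffer, applies the built observation/action scalers, and fits the scikit-learn `NearestNeighbors` index `self._nbsr` on keys of the form `[beta * observation, action]`. A minimal standalone sketch of that indexing step; the array shapes and the `beta` value are illustrative assumptions, not values taken from this repository.

import numpy as np
from sklearn.neighbors import NearestNeighbors

# Toy stand-ins for the flattened dataset transitions (shapes are assumptions).
observations = np.random.randn(1000, 17).astype(np.float32)  # (N, obs_dim)
actions = np.random.randn(1000, 6).astype(np.float32)        # (N, action_dim)
beta = 2.0  # weight on the observation part of each key

# Each key is [beta * s, a]; beta balances state similarity against action
# similarity when searching for the closest dataset pair.
keys = np.concatenate([np.multiply(observations, beta), actions], axis=1)

nbsr = NearestNeighbors(n_neighbors=1).fit(keys)
distances, indices = nbsr.kneighbors(keys[:5])
print(distances.shape, indices.shape)  # (5, 1) (5, 1)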
13 changes: 8 additions & 5 deletions d3rlpy/algos/qlearning/torch/prdc_impl.py
@@ -8,7 +8,7 @@
 from ....torch_utility import TorchMiniBatch
 from ....types import Shape
 from .ddpg_impl import DDPGBaseActorLoss, DDPGModules
-from .td3_plus_bc_impl import TD3PlusBCImpl
+from .td3_impl import TD3Impl
 
 __all__ = ["PRDCImpl"]
 
@@ -18,8 +18,9 @@ class PRDCActorLoss(DDPGBaseActorLoss):
     dc_loss: torch.Tensor
 
 
-class PRDCImpl(TD3PlusBCImpl):
-    _beta: float = 2.0
+class PRDCImpl(TD3Impl):
+    _alpha: float
+    _beta: float
     _nbsr: NearestNeighbors
 
     def __init__(
@@ -50,11 +51,11 @@ def __init__(
             tau=tau,
             target_smoothing_sigma=target_smoothing_sigma,
             target_smoothing_clip=target_smoothing_clip,
-            alpha=alpha,
             update_actor_interval=update_actor_interval,
             compiled=compiled,
             device=device,
         )
+        self._alpha = alpha
         self._beta = beta
         self._nbsr = nbsr
 
@@ -66,7 +67,9 @@ def compute_actor_loss(
         )[0]
         lam = self._alpha / (q_t.abs().mean()).detach()
         key = (
-            torch.cat([self._beta * batch.observations, action.squashed_mu], dim=-1)
+            torch.cat(
+                [torch.mul(batch.observations, self._beta), action.squashed_mu], dim=-1
+            )
             .detach()
             .cpu()
             .numpy()
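For context on `compute_actor_loss` above: the policy action is concatenated with the beta-weighted batch observations into a detached query key, and the nearest dataset key is looked up in the pre-fitted `NearestNeighbors` index to form the dataset-constraint term (`dc_loss`). A rough sketch of that lookup; the shapes, the squared-distance form of the constraint, and all values are assumptions for illustration, not code copied from this repository.

import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors

beta, obs_dim, action_dim = 2.0, 17, 6  # assumed values

# Index over dataset keys [beta * s, a], as fitted in prdc.py above.
dataset_keys = np.random.randn(1000, obs_dim + action_dim).astype(np.float32)
nbsr = NearestNeighbors(n_neighbors=1).fit(dataset_keys)

observations = torch.randn(256, obs_dim)       # batch of states
policy_actions = torch.randn(256, action_dim)  # stands in for pi(s)

# Detached query key [beta * s, pi(s)], mirroring the construction in the diff.
key = (
    torch.cat([torch.mul(observations, beta), policy_actions], dim=-1)
    .detach()
    .cpu()
    .numpy()
)
_, idx = nbsr.kneighbors(key)

# One common dataset-constraint form: squared distance between the policy
# action and the action part of the nearest dataset key.
nearest_actions = torch.tensor(dataset_keys[idx[:, 0], -action_dim:])
dc_loss = ((policy_actions - nearest_actions) ** 2).sum(dim=1).mean()
print(dc_loss.item())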
3 changes: 3 additions & 0 deletions mypy.ini
@@ -71,3 +71,6 @@ follow_imports_for_stubs = True
 ignore_missing_imports = True
 follow_imports = skip
 follow_imports_for_stubs = True
+
+[mypy-sklearn.*]
+ignore_missing_imports = True
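The new `[mypy-sklearn.*]` section tells mypy to treat scikit-learn modules as untyped instead of erroring, since scikit-learn ships without type stubs; without it, the `NearestNeighbors` import used by the PRDC files above would fail type checking. A small illustration (the comment paraphrases the error, whose exact wording varies by mypy version):

# Without the [mypy-sklearn.*] override, mypy reports this import as missing
# library stubs; with ignore_missing_imports = True, sklearn is treated as Any.
from sklearn.neighbors import NearestNeighbors  # noqa: F401  (illustration only)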
6 changes: 5 additions & 1 deletion tests/algos/qlearning/test_prdc.py
@@ -3,7 +3,11 @@
 import pytest
 
 from d3rlpy.algos.qlearning.prdc import PRDCConfig
-from d3rlpy.models import MeanQFunctionFactory, QFunctionFactory, QRQFunctionFactory
+from d3rlpy.models import (
+    MeanQFunctionFactory,
+    QFunctionFactory,
+    QRQFunctionFactory,
+)
 from d3rlpy.types import Shape
 
 from ...models.torch.model_test import DummyEncoderFactory