Skip to content

Commit

Permalink
Add ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Nov 3, 2024
1 parent 2c67af3 commit c72f388
Show file tree
Hide file tree
Showing 74 changed files with 346 additions and 992 deletions.
5 changes: 1 addition & 4 deletions .github/workflows/format_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@ jobs:
pip install Cython numpy
pip install -e .
pip install -r dev.requirements.txt
- name: Check format
run: |
./scripts/format
- name: Linter
- name: Static analysis
run: |
./scripts/lint
Expand Down
10 changes: 2 additions & 8 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,10 @@ $ ./scripts/test
```

### Coding style check
This repository is styled with [black](https://github.com/psf/black) formatter.
Also, [isort](https://github.com/PyCQA/isort) is used to format package imports.
This repository is styled and analyzed with [Ruff](https://docs.astral.sh/ruff/).
[docformatter](https://github.com/PyCQA/docformatter) is additionally used to format docstrings.
```
$ ./scripts/format
```

### Linter
This repository is fully type-annotated and checked by [mypy](https://github.com/python/mypy).
Also, [pylint](https://github.com/PyCQA/pylint) checks code consistency.
Before you submit your PR, please execute this command:
```
$ ./scripts/lint
```
17 changes: 7 additions & 10 deletions d3rlpy/algos/qlearning/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
from collections import defaultdict
from typing import (
Callable,
Dict,
Generator,
Generic,
List,
Optional,
Sequence,
Tuple,
TypeVar,
)

Expand Down Expand Up @@ -67,13 +64,13 @@

class QLearningAlgoImplBase(ImplBase):
@train_api
def update(self, batch: TorchMiniBatch, grad_step: int) -> Dict[str, float]:
def update(self, batch: TorchMiniBatch, grad_step: int) -> dict[str, float]:
return self.inner_update(batch, grad_step)

@abstractmethod
def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
) -> dict[str, float]:
pass

@eval_api
Expand Down Expand Up @@ -382,10 +379,10 @@ def fit(
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
show_progress: bool = True,
save_interval: int = 1,
evaluators: Optional[Dict[str, EvaluatorProtocol]] = None,
evaluators: Optional[dict[str, EvaluatorProtocol]] = None,
callback: Optional[Callable[[Self, int, int], None]] = None,
epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
) -> List[Tuple[int, Dict[str, float]]]:
) -> list[tuple[int, dict[str, float]]]:
"""Trains with given dataset.
.. code-block:: python
Expand Down Expand Up @@ -448,10 +445,10 @@ def fitter(
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
show_progress: bool = True,
save_interval: int = 1,
evaluators: Optional[Dict[str, EvaluatorProtocol]] = None,
evaluators: Optional[dict[str, EvaluatorProtocol]] = None,
callback: Optional[Callable[[Self, int, int], None]] = None,
epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
) -> Generator[Tuple[int, Dict[str, float]], None, None]:
) -> Generator[tuple[int, dict[str, float]], None, None]:
"""Iterate over epochs steps to train with the given dataset. At each
iteration algo methods and properties can be changed or queried.
Expand Down Expand Up @@ -859,7 +856,7 @@ def collect(

return buffer

def update(self, batch: TransitionMiniBatch) -> Dict[str, float]:
def update(self, batch: TransitionMiniBatch) -> dict[str, float]:
"""Update parameters with mini-batch of data.
Args:
Expand Down
13 changes: 8 additions & 5 deletions d3rlpy/algos/qlearning/random_policy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import dataclasses
from typing import Dict

import numpy as np

Expand Down Expand Up @@ -35,7 +34,9 @@ class RandomPolicyConfig(LearnableConfig):
distribution: str = "uniform"
normal_std: float = 1.0

def create(self, device: DeviceArg = False, enable_ddp: bool = False) -> "RandomPolicy": # type: ignore
def create( # type: ignore
self, device: DeviceArg = False, enable_ddp: bool = False
) -> "RandomPolicy":
return RandomPolicy(self)

@staticmethod
Expand Down Expand Up @@ -83,7 +84,7 @@ def sample_action(self, x: Observation) -> NDArray:
def predict_value(self, x: Observation, action: NDArray) -> NDArray:
raise NotImplementedError

def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]:
def inner_update(self, batch: TorchMiniBatch) -> dict[str, float]:
raise NotImplementedError

def get_action_type(self) -> ActionSpace:
Expand All @@ -98,7 +99,9 @@ class DiscreteRandomPolicyConfig(LearnableConfig):
``fit`` and ``fit_online`` methods will raise exceptions.
"""

def create(self, device: DeviceArg = False, enable_ddp: bool = False) -> "DiscreteRandomPolicy": # type: ignore
def create( # type: ignore
self, device: DeviceArg = False, enable_ddp: bool = False
) -> "DiscreteRandomPolicy":
return DiscreteRandomPolicy(self)

@staticmethod
Expand Down Expand Up @@ -128,7 +131,7 @@ def sample_action(self, x: Observation) -> NDArray:
def predict_value(self, x: Observation, action: NDArray) -> NDArray:
raise NotImplementedError

def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]:
def inner_update(self, batch: TorchMiniBatch) -> dict[str, float]:
raise NotImplementedError

def get_action_type(self) -> ActionSpace:
Expand Down
6 changes: 3 additions & 3 deletions d3rlpy/algos/qlearning/torch/bc_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
from abc import ABCMeta, abstractmethod
from typing import Callable, Dict, Union
from typing import Callable, Union

import torch
from torch.optim import Optimizer
Expand Down Expand Up @@ -60,7 +60,7 @@ def compute_imitator_grad(self, batch: TorchMiniBatch) -> ImitationLoss:
loss.loss.backward()
return loss

def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
def update_imitator(self, batch: TorchMiniBatch) -> dict[str, float]:
loss = self._compute_imitator_grad(batch)
self._modules.optim.step()
return asdict_as_float(loss)
Expand All @@ -81,7 +81,7 @@ def inner_predict_value(

def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
) -> dict[str, float]:
return self.update_imitator(batch)


Expand Down
10 changes: 5 additions & 5 deletions d3rlpy/algos/qlearning/torch/bcq_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
import math
from typing import Callable, Dict, cast
from typing import Callable, cast

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -50,7 +50,7 @@ class BCQModules(DDPGBaseModules):

class BCQImpl(DDPGBaseImpl):
_modules: BCQModules
_compute_imitator_grad: Callable[[TorchMiniBatch], Dict[str, torch.Tensor]]
_compute_imitator_grad: Callable[[TorchMiniBatch], dict[str, torch.Tensor]]
_lam: float
_n_action_samples: int
_action_flexibility: float
Expand Down Expand Up @@ -124,7 +124,7 @@ def compute_actor_grad(self, batch: TorchMiniBatch) -> DDPGBaseActorLoss:

def compute_imitator_grad(
self, batch: TorchMiniBatch
) -> Dict[str, torch.Tensor]:
) -> dict[str, torch.Tensor]:
self._modules.vae_optim.zero_grad()
loss = compute_vae_error(
vae_encoder=self._modules.vae_encoder,
Expand All @@ -136,7 +136,7 @@ def compute_imitator_grad(
loss.backward()
return {"loss": loss}

def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
def update_imitator(self, batch: TorchMiniBatch) -> dict[str, float]:
loss = self._compute_imitator_grad(batch)
self._modules.vae_optim.step()
return {"vae_loss": float(loss["loss"].cpu().detach().numpy())}
Expand Down Expand Up @@ -214,7 +214,7 @@ def update_actor_target(self) -> None:

def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
) -> dict[str, float]:
metrics = {}

metrics.update(self.update_imitator(batch))
Expand Down
16 changes: 8 additions & 8 deletions d3rlpy/algos/qlearning/torch/bear_impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import Callable, Dict, Optional
from typing import Callable, Optional

import torch

Expand Down Expand Up @@ -61,9 +61,9 @@ class BEARActorLoss(SACActorLoss):
class BEARImpl(SACImpl):
_modules: BEARModules
_compute_warmup_actor_grad: Callable[
[TorchMiniBatch], Dict[str, torch.Tensor]
[TorchMiniBatch], dict[str, torch.Tensor]
]
_compute_imitator_grad: Callable[[TorchMiniBatch], Dict[str, torch.Tensor]]
_compute_imitator_grad: Callable[[TorchMiniBatch], dict[str, torch.Tensor]]
_alpha_threshold: float
_lam: float
_n_action_samples: int
Expand Down Expand Up @@ -143,13 +143,13 @@ def compute_actor_loss(

def compute_warmup_actor_grad(
self, batch: TorchMiniBatch
) -> Dict[str, torch.Tensor]:
) -> dict[str, torch.Tensor]:
self._modules.actor_optim.zero_grad()
loss = self._compute_mmd_loss(batch.observations)
loss.backward()
return {"loss": loss}

def warmup_actor(self, batch: TorchMiniBatch) -> Dict[str, float]:
def warmup_actor(self, batch: TorchMiniBatch) -> dict[str, float]:
loss = self._compute_warmup_actor_grad(batch)
self._modules.actor_optim.step()
return {"actor_loss": float(loss["loss"].cpu().detach().numpy())}
Expand All @@ -161,13 +161,13 @@ def _compute_mmd_loss(self, obs_t: TorchObservation) -> torch.Tensor:

def compute_imitator_grad(
self, batch: TorchMiniBatch
) -> Dict[str, torch.Tensor]:
) -> dict[str, torch.Tensor]:
self._modules.vae_optim.zero_grad()
loss = self.compute_imitator_loss(batch)
loss.backward()
return {"loss": loss}

def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
def update_imitator(self, batch: TorchMiniBatch) -> dict[str, float]:
loss = self._compute_imitator_grad(batch)
self._modules.vae_optim.step()
return {"imitator_loss": float(loss["loss"].cpu().detach().numpy())}
Expand Down Expand Up @@ -301,7 +301,7 @@ def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor:

def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
) -> dict[str, float]:
metrics = {}
metrics.update(self.update_imitator(batch))
metrics.update(self.update_critic(batch))
Expand Down
3 changes: 1 addition & 2 deletions d3rlpy/algos/qlearning/torch/cal_ql_impl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import Tuple

import torch

Expand All @@ -14,7 +13,7 @@ def _compute_policy_is_values(
policy_obs: TorchObservation,
value_obs: TorchObservation,
returns_to_go: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
values, log_probs = super()._compute_policy_is_values(
policy_obs=policy_obs,
value_obs=value_obs,
Expand Down
6 changes: 3 additions & 3 deletions d3rlpy/algos/qlearning/torch/cql_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
import math
from typing import Optional, Tuple
from typing import Optional

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -119,7 +119,7 @@ def _compute_policy_is_values(
policy_obs: TorchObservation,
value_obs: TorchObservation,
returns_to_go: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
return sample_q_values_with_policy(
policy=self._modules.policy,
q_func_forwarder=self._q_func_forwarder,
Expand All @@ -131,7 +131,7 @@ def _compute_policy_is_values(

def _compute_random_is_values(
self, obs: TorchObservation
) -> Tuple[torch.Tensor, float]:
) -> tuple[torch.Tensor, float]:
# (batch, observation) -> (batch, n, observation)
repeated_obs = expand_and_repeat_recursively(
obs, self._n_action_samples
Expand Down
3 changes: 1 addition & 2 deletions d3rlpy/algos/qlearning/torch/crr_impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import dataclasses
from typing import Dict

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -186,7 +185,7 @@ def update_actor_target(self) -> None:

def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
) -> dict[str, float]:
metrics = {}
metrics.update(self.update_critic(batch))
metrics.update(self.update_actor(batch))
Expand Down
10 changes: 5 additions & 5 deletions d3rlpy/algos/qlearning/torch/ddpg_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
from abc import ABCMeta, abstractmethod
from typing import Callable, Dict
from typing import Callable

import torch
from torch import nn
Expand Down Expand Up @@ -105,7 +105,7 @@ def compute_critic_grad(self, batch: TorchMiniBatch) -> DDPGBaseCriticLoss:
loss.critic_loss.backward()
return loss

def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]:
def update_critic(self, batch: TorchMiniBatch) -> dict[str, float]:
loss = self._compute_critic_grad(batch)
self._modules.critic_optim.step()
return asdict_as_float(loss)
Expand All @@ -130,7 +130,7 @@ def compute_actor_grad(self, batch: TorchMiniBatch) -> DDPGBaseActorLoss:
loss.actor_loss.backward()
return loss

def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]:
def update_actor(self, batch: TorchMiniBatch) -> dict[str, float]:
# Q function should be inference mode for stability
self._modules.q_funcs.eval()
loss = self._compute_actor_grad(batch)
Expand All @@ -139,7 +139,7 @@ def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]:

def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
) -> dict[str, float]:
metrics = {}
metrics.update(self.update_critic(batch))
metrics.update(self.update_actor(batch))
Expand Down Expand Up @@ -241,7 +241,7 @@ def update_actor_target(self) -> None:

def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
) -> dict[str, float]:
metrics = super().inner_update(batch, grad_step)
self.update_actor_target()
return metrics
4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/torch/dqn_impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import Callable, Dict
from typing import Callable

import torch
from torch import nn
Expand Down Expand Up @@ -79,7 +79,7 @@ def compute_grad(self, batch: TorchMiniBatch) -> DQNLoss:

def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
) -> dict[str, float]:
loss = self._compute_grad(batch)
self._modules.optim.step()
if grad_step % self._target_update_interval == 0:
Expand Down
Loading

0 comments on commit c72f388

Please sign in to comment.