Switch to ruff #429

Merged · 3 commits · Nov 3, 2024
5 changes: 1 addition & 4 deletions .github/workflows/format_check.yml
@@ -25,10 +25,7 @@ jobs:
pip install Cython numpy
pip install -e .
pip install -r dev.requirements.txt
- name: Check format
run: |
./scripts/format
- name: Linter
- name: Static analysis
run: |
./scripts/lint

10 changes: 2 additions & 8 deletions CONTRIBUTING.md
@@ -29,16 +29,10 @@ $ ./scripts/test
```

### Coding style check
This repository is styled with [black](https://github.com/psf/black) formatter.
Also, [isort](https://github.com/PyCQA/isort) is used to format package imports.
This repository is styled and analyzed with [Ruff](https://docs.astral.sh/ruff/).
[docformatter](https://github.com/PyCQA/docformatter) is additionally used to format docstrings.
```
$ ./scripts/format
```

### Linter
This repository is fully type-annotated and checked by [mypy](https://github.com/python/mypy).
Also, [pylint](https://github.com/PyCQA/pylint) checks code consistency.
Before you submit your PR, please execute this command:
```
$ ./scripts/lint
```
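Most of the Python changes in the files below follow from this switch: imports of `Dict`, `List`, and `Tuple` from `typing` are dropped in favor of the built-in generics from PEP 585, which Ruff's pyupgrade (UP) rules can flag and auto-fix. A minimal before/after sketch of that annotation style, using an illustrative function that is not part of d3rlpy:

```python
# Before: aliases imported from the typing module (flagged by Ruff's UP rules).
# from typing import Dict, List, Tuple
# def summarize(losses: List[float]) -> Tuple[int, Dict[str, float]]: ...


# After: built-in generics, no typing import needed (Python 3.9+).
def summarize(losses: list[float]) -> tuple[int, dict[str, float]]:
    """Return the number of steps and simple aggregate loss metrics."""
    return len(losses), {
        "mean_loss": sum(losses) / len(losses),
        "max_loss": max(losses),
    }


if __name__ == "__main__":
    print(summarize([0.5, 0.25, 0.75]))
```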
17 changes: 7 additions & 10 deletions d3rlpy/algos/qlearning/base.py
@@ -2,13 +2,10 @@
from collections import defaultdict
from typing import (
Callable,
Dict,
Generator,
Generic,
List,
Optional,
Sequence,
Tuple,
TypeVar,
)

@@ -67,13 +64,13 @@

class QLearningAlgoImplBase(ImplBase):
@train_api
def update(self, batch: TorchMiniBatch, grad_step: int) -> Dict[str, float]:
def update(self, batch: TorchMiniBatch, grad_step: int) -> dict[str, float]:
return self.inner_update(batch, grad_step)

@abstractmethod
def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
) -> dict[str, float]:
pass

@eval_api
@@ -382,10 +379,10 @@ def fit(
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
show_progress: bool = True,
save_interval: int = 1,
evaluators: Optional[Dict[str, EvaluatorProtocol]] = None,
evaluators: Optional[dict[str, EvaluatorProtocol]] = None,
callback: Optional[Callable[[Self, int, int], None]] = None,
epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
) -> List[Tuple[int, Dict[str, float]]]:
) -> list[tuple[int, dict[str, float]]]:
"""Trains with given dataset.

.. code-block:: python
@@ -448,10 +445,10 @@ def fitter(
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
show_progress: bool = True,
save_interval: int = 1,
evaluators: Optional[Dict[str, EvaluatorProtocol]] = None,
evaluators: Optional[dict[str, EvaluatorProtocol]] = None,
callback: Optional[Callable[[Self, int, int], None]] = None,
epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
) -> Generator[Tuple[int, Dict[str, float]], None, None]:
) -> Generator[tuple[int, dict[str, float]], None, None]:
"""Iterate over epochs steps to train with the given dataset. At each
iteration algo methods and properties can be changed or queried.

@@ -859,7 +856,7 @@ def collect(

return buffer

def update(self, batch: TransitionMiniBatch) -> Dict[str, float]:
def update(self, batch: TransitionMiniBatch) -> dict[str, float]:
"""Update parameters with mini-batch of data.

Args:
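As the annotations above show, `fit()` returns `list[tuple[int, dict[str, float]]]`, i.e. one `(epoch, metrics)` pair per epoch, with any `evaluators` results merged into the metrics dict. A short usage sketch of that signature; the dataset helper and evaluator names come from d3rlpy's public API and may need adjusting to the installed version:

```python
import d3rlpy

# Small offline dataset bundled with d3rlpy (downloads on first use).
dataset, env = d3rlpy.datasets.get_cartpole()

dqn = d3rlpy.algos.DQNConfig().create()

# fit() returns list[tuple[int, dict[str, float]]]: (epoch, metrics) per epoch.
results = dqn.fit(
    dataset,
    n_steps=1000,
    n_steps_per_epoch=500,
    evaluators={
        "td_error": d3rlpy.metrics.TDErrorEvaluator(episodes=dataset.episodes),
    },
)

for epoch, metrics in results:
    print(epoch, metrics["td_error"])
```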
13 changes: 8 additions & 5 deletions d3rlpy/algos/qlearning/random_policy.py
@@ -1,5 +1,4 @@
import dataclasses
from typing import Dict

import numpy as np

@@ -35,7 +34,9 @@ class RandomPolicyConfig(LearnableConfig):
distribution: str = "uniform"
normal_std: float = 1.0

def create(self, device: DeviceArg = False, enable_ddp: bool = False) -> "RandomPolicy": # type: ignore
def create( # type: ignore
self, device: DeviceArg = False, enable_ddp: bool = False
) -> "RandomPolicy":
return RandomPolicy(self)

@staticmethod
@@ -83,7 +84,7 @@ def sample_action(self, x: Observation) -> NDArray:
def predict_value(self, x: Observation, action: NDArray) -> NDArray:
raise NotImplementedError

def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]:
def inner_update(self, batch: TorchMiniBatch) -> dict[str, float]:
raise NotImplementedError

def get_action_type(self) -> ActionSpace:
@@ -98,7 +99,9 @@ class DiscreteRandomPolicyConfig(LearnableConfig):
``fit`` and ``fit_online`` methods will raise exceptions.
"""

def create(self, device: DeviceArg = False, enable_ddp: bool = False) -> "DiscreteRandomPolicy": # type: ignore
def create( # type: ignore
self, device: DeviceArg = False, enable_ddp: bool = False
) -> "DiscreteRandomPolicy":
return DiscreteRandomPolicy(self)

@staticmethod
@@ -128,7 +131,7 @@ def sample_action(self, x: Observation) -> NDArray:
def predict_value(self, x: Observation, action: NDArray) -> NDArray:
raise NotImplementedError

def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]:
def inner_update(self, batch: TorchMiniBatch) -> dict[str, float]:
raise NotImplementedError

def get_action_type(self) -> ActionSpace:
6 changes: 3 additions & 3 deletions d3rlpy/algos/qlearning/torch/bc_impl.py
@@ -1,6 +1,6 @@
import dataclasses
from abc import ABCMeta, abstractmethod
from typing import Callable, Dict, Union
from typing import Callable, Union

import torch
from torch.optim import Optimizer
@@ -60,7 +60,7 @@ def compute_imitator_grad(self, batch: TorchMiniBatch) -> ImitationLoss:
loss.loss.backward()
return loss

def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
def update_imitator(self, batch: TorchMiniBatch) -> dict[str, float]:
loss = self._compute_imitator_grad(batch)
self._modules.optim.step()
return asdict_as_float(loss)
@@ -81,7 +81,7 @@ def inner_predict_value(

def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
) -> dict[str, float]:
return self.update_imitator(batch)


10 changes: 5 additions & 5 deletions d3rlpy/algos/qlearning/torch/bcq_impl.py
@@ -1,6 +1,6 @@
import dataclasses
import math
from typing import Callable, Dict, cast
from typing import Callable, cast

import torch
import torch.nn.functional as F
@@ -50,7 +50,7 @@ class BCQModules(DDPGBaseModules):

class BCQImpl(DDPGBaseImpl):
_modules: BCQModules
_compute_imitator_grad: Callable[[TorchMiniBatch], Dict[str, torch.Tensor]]
_compute_imitator_grad: Callable[[TorchMiniBatch], dict[str, torch.Tensor]]
_lam: float
_n_action_samples: int
_action_flexibility: float
@@ -124,7 +124,7 @@ def compute_actor_grad(self, batch: TorchMiniBatch) -> DDPGBaseActorLoss:

def compute_imitator_grad(
self, batch: TorchMiniBatch
) -> Dict[str, torch.Tensor]:
) -> dict[str, torch.Tensor]:
self._modules.vae_optim.zero_grad()
loss = compute_vae_error(
vae_encoder=self._modules.vae_encoder,
@@ -136,7 +136,7 @@ def compute_imitator_grad(
loss.backward()
return {"loss": loss}

def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
def update_imitator(self, batch: TorchMiniBatch) -> dict[str, float]:
loss = self._compute_imitator_grad(batch)
self._modules.vae_optim.step()
return {"vae_loss": float(loss["loss"].cpu().detach().numpy())}
@@ -214,7 +214,7 @@ def update_actor_target(self) -> None:

def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
) -> dict[str, float]:
metrics = {}

metrics.update(self.update_imitator(batch))
16 changes: 8 additions & 8 deletions d3rlpy/algos/qlearning/torch/bear_impl.py
@@ -1,5 +1,5 @@
import dataclasses
from typing import Callable, Dict, Optional
from typing import Callable, Optional

import torch

@@ -61,9 +61,9 @@ class BEARActorLoss(SACActorLoss):
class BEARImpl(SACImpl):
_modules: BEARModules
_compute_warmup_actor_grad: Callable[
[TorchMiniBatch], Dict[str, torch.Tensor]
[TorchMiniBatch], dict[str, torch.Tensor]
]
_compute_imitator_grad: Callable[[TorchMiniBatch], Dict[str, torch.Tensor]]
_compute_imitator_grad: Callable[[TorchMiniBatch], dict[str, torch.Tensor]]
_alpha_threshold: float
_lam: float
_n_action_samples: int
@@ -143,13 +143,13 @@ def compute_actor_loss(

def compute_warmup_actor_grad(
self, batch: TorchMiniBatch
) -> Dict[str, torch.Tensor]:
) -> dict[str, torch.Tensor]:
self._modules.actor_optim.zero_grad()
loss = self._compute_mmd_loss(batch.observations)
loss.backward()
return {"loss": loss}

def warmup_actor(self, batch: TorchMiniBatch) -> Dict[str, float]:
def warmup_actor(self, batch: TorchMiniBatch) -> dict[str, float]:
loss = self._compute_warmup_actor_grad(batch)
self._modules.actor_optim.step()
return {"actor_loss": float(loss["loss"].cpu().detach().numpy())}
@@ -161,13 +161,13 @@ def _compute_mmd_loss(self, obs_t: TorchObservation) -> torch.Tensor:

def compute_imitator_grad(
self, batch: TorchMiniBatch
) -> Dict[str, torch.Tensor]:
) -> dict[str, torch.Tensor]:
self._modules.vae_optim.zero_grad()
loss = self.compute_imitator_loss(batch)
loss.backward()
return {"loss": loss}

def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
def update_imitator(self, batch: TorchMiniBatch) -> dict[str, float]:
loss = self._compute_imitator_grad(batch)
self._modules.vae_optim.step()
return {"imitator_loss": float(loss["loss"].cpu().detach().numpy())}
@@ -301,7 +301,7 @@ def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor:

def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
) -> dict[str, float]:
metrics = {}
metrics.update(self.update_imitator(batch))
metrics.update(self.update_critic(batch))
3 changes: 1 addition & 2 deletions d3rlpy/algos/qlearning/torch/cal_ql_impl.py
@@ -1,4 +1,3 @@
from typing import Tuple

import torch

@@ -14,7 +13,7 @@ def _compute_policy_is_values(
policy_obs: TorchObservation,
value_obs: TorchObservation,
returns_to_go: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
values, log_probs = super()._compute_policy_is_values(
policy_obs=policy_obs,
value_obs=value_obs,
6 changes: 3 additions & 3 deletions d3rlpy/algos/qlearning/torch/cql_impl.py
@@ -1,6 +1,6 @@
import dataclasses
import math
from typing import Optional, Tuple
from typing import Optional

import torch
import torch.nn.functional as F
@@ -119,7 +119,7 @@ def _compute_policy_is_values(
policy_obs: TorchObservation,
value_obs: TorchObservation,
returns_to_go: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
return sample_q_values_with_policy(
policy=self._modules.policy,
q_func_forwarder=self._q_func_forwarder,
@@ -131,7 +131,7 @@ def _compute_policy_is_values(

def _compute_random_is_values(
self, obs: TorchObservation
) -> Tuple[torch.Tensor, float]:
) -> tuple[torch.Tensor, float]:
# (batch, observation) -> (batch, n, observation)
repeated_obs = expand_and_repeat_recursively(
obs, self._n_action_samples
3 changes: 1 addition & 2 deletions d3rlpy/algos/qlearning/torch/crr_impl.py
@@ -1,5 +1,4 @@
import dataclasses
from typing import Dict

import torch
import torch.nn.functional as F
@@ -186,7 +185,7 @@ def update_actor_target(self) -> None:

def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
) -> dict[str, float]:
metrics = {}
metrics.update(self.update_critic(batch))
metrics.update(self.update_actor(batch))
10 changes: 5 additions & 5 deletions d3rlpy/algos/qlearning/torch/ddpg_impl.py
@@ -1,6 +1,6 @@
import dataclasses
from abc import ABCMeta, abstractmethod
from typing import Callable, Dict
from typing import Callable

import torch
from torch import nn
@@ -105,7 +105,7 @@ def compute_critic_grad(self, batch: TorchMiniBatch) -> DDPGBaseCriticLoss:
loss.critic_loss.backward()
return loss

def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]:
def update_critic(self, batch: TorchMiniBatch) -> dict[str, float]:
loss = self._compute_critic_grad(batch)
self._modules.critic_optim.step()
return asdict_as_float(loss)
@@ -130,7 +130,7 @@ def compute_actor_grad(self, batch: TorchMiniBatch) -> DDPGBaseActorLoss:
loss.actor_loss.backward()
return loss

def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]:
def update_actor(self, batch: TorchMiniBatch) -> dict[str, float]:
# Q function should be inference mode for stability
self._modules.q_funcs.eval()
loss = self._compute_actor_grad(batch)
@@ -139,7 +139,7 @@ def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]:

def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
) -> dict[str, float]:
metrics = {}
metrics.update(self.update_critic(batch))
metrics.update(self.update_actor(batch))
@@ -241,7 +241,7 @@ def update_actor_target(self) -> None:

def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
) -> dict[str, float]:
metrics = super().inner_update(batch, grad_step)
self.update_actor_target()
return metrics
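The DDPG-family implementations above share a two-step pattern: a `compute_*_grad` method zeroes the gradients, builds the loss, and calls `backward()`, returning raw tensors, while the corresponding `update_*` wrapper steps the optimizer and converts the result to plain floats for logging (hence the `dict[str, torch.Tensor]` vs `dict[str, float]` annotations). A condensed, self-contained sketch of that split; the class and method names here are illustrative, not d3rlpy internals:

```python
# Condensed sketch of the grad-compute / optimizer-step split used above.
# Class, method, and attribute names are illustrative, not d3rlpy internals.
import torch
import torch.nn.functional as F


class TinyImpl:
    def __init__(self, model: torch.nn.Module) -> None:
        self._model = model
        self._optim = torch.optim.SGD(model.parameters(), lr=1e-2)

    def compute_grad(
        self, x: torch.Tensor, y: torch.Tensor
    ) -> dict[str, torch.Tensor]:
        # Zero grads, build the loss, and backpropagate; return raw tensors.
        self._optim.zero_grad()
        loss = F.mse_loss(self._model(x), y)
        loss.backward()
        return {"loss": loss}

    def update(self, x: torch.Tensor, y: torch.Tensor) -> dict[str, float]:
        # Step the optimizer and convert tensors to floats for logging.
        loss = self.compute_grad(x, y)
        self._optim.step()
        return {"loss": float(loss["loss"].detach().cpu().numpy())}


if __name__ == "__main__":
    impl = TinyImpl(torch.nn.Linear(4, 1))
    print(impl.update(torch.randn(8, 4), torch.randn(8, 1)))
```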
4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/torch/dqn_impl.py
@@ -1,5 +1,5 @@
import dataclasses
from typing import Callable, Dict
from typing import Callable

import torch
from torch import nn
@@ -79,7 +79,7 @@ def compute_grad(self, batch: TorchMiniBatch) -> DQNLoss:

def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
) -> dict[str, float]:
loss = self._compute_grad(batch)
self._modules.optim.step()
if grad_step % self._target_update_interval == 0:
Expand Down
Loading
Loading