diff --git a/.github/workflows/format_check.yml b/.github/workflows/format_check.yml index 9601982d..0e38aa56 100644 --- a/.github/workflows/format_check.yml +++ b/.github/workflows/format_check.yml @@ -25,10 +25,7 @@ jobs: pip install Cython numpy pip install -e . pip install -r dev.requirements.txt - - name: Check format - run: | - ./scripts/format - - name: Linter + - name: Static analysis run: | ./scripts/lint diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8a826dfe..3dcb9bea 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -29,16 +29,10 @@ $ ./scripts/test ``` ### Coding style check -This repository is styled with [black](https://github.com/psf/black) formatter. -Also, [isort](https://github.com/PyCQA/isort) is used to format package imports. +This repository is styled and analyzed with [Ruff](https://docs.astral.sh/ruff/). [docformatter](https://github.com/PyCQA/docformatter) is additionally used to format docstrings. -``` -$ ./scripts/format -``` - -### Linter This repository is fully type-annotated and checked by [mypy](https://github.com/python/mypy). -Also, [pylint](https://github.com/PyCQA/pylint) checks code consistency. +Before you submit your PR, please execute this command: ``` $ ./scripts/lint ``` diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py index 1dc4988f..52b1013a 100644 --- a/d3rlpy/algos/qlearning/base.py +++ b/d3rlpy/algos/qlearning/base.py @@ -2,13 +2,10 @@ from collections import defaultdict from typing import ( Callable, - Dict, Generator, Generic, - List, Optional, Sequence, - Tuple, TypeVar, ) @@ -67,13 +64,13 @@ class QLearningAlgoImplBase(ImplBase): @train_api - def update(self, batch: TorchMiniBatch, grad_step: int) -> Dict[str, float]: + def update(self, batch: TorchMiniBatch, grad_step: int) -> dict[str, float]: return self.inner_update(batch, grad_step) @abstractmethod def inner_update( self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + ) -> dict[str, float]: pass @eval_api @@ -382,10 +379,10 @@ def fit( logger_adapter: LoggerAdapterFactory = FileAdapterFactory(), show_progress: bool = True, save_interval: int = 1, - evaluators: Optional[Dict[str, EvaluatorProtocol]] = None, + evaluators: Optional[dict[str, EvaluatorProtocol]] = None, callback: Optional[Callable[[Self, int, int], None]] = None, epoch_callback: Optional[Callable[[Self, int, int], None]] = None, - ) -> List[Tuple[int, Dict[str, float]]]: + ) -> list[tuple[int, dict[str, float]]]: """Trains with given dataset. .. code-block:: python @@ -448,10 +445,10 @@ def fitter( logger_adapter: LoggerAdapterFactory = FileAdapterFactory(), show_progress: bool = True, save_interval: int = 1, - evaluators: Optional[Dict[str, EvaluatorProtocol]] = None, + evaluators: Optional[dict[str, EvaluatorProtocol]] = None, callback: Optional[Callable[[Self, int, int], None]] = None, epoch_callback: Optional[Callable[[Self, int, int], None]] = None, - ) -> Generator[Tuple[int, Dict[str, float]], None, None]: + ) -> Generator[tuple[int, dict[str, float]], None, None]: """Iterate over epochs steps to train with the given dataset. At each iteration algo methods and properties can be changed or queried. @@ -859,7 +856,7 @@ def collect( return buffer - def update(self, batch: TransitionMiniBatch) -> Dict[str, float]: + def update(self, batch: TransitionMiniBatch) -> dict[str, float]: """Update parameters with mini-batch of data. Args: diff --git a/d3rlpy/algos/qlearning/random_policy.py b/d3rlpy/algos/qlearning/random_policy.py index 0623a7c5..977cf7a5 100644 --- a/d3rlpy/algos/qlearning/random_policy.py +++ b/d3rlpy/algos/qlearning/random_policy.py @@ -1,5 +1,4 @@ import dataclasses -from typing import Dict import numpy as np @@ -35,7 +34,9 @@ class RandomPolicyConfig(LearnableConfig): distribution: str = "uniform" normal_std: float = 1.0 - def create(self, device: DeviceArg = False, enable_ddp: bool = False) -> "RandomPolicy": # type: ignore + def create( # type: ignore + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "RandomPolicy": return RandomPolicy(self) @staticmethod @@ -83,7 +84,7 @@ def sample_action(self, x: Observation) -> NDArray: def predict_value(self, x: Observation, action: NDArray) -> NDArray: raise NotImplementedError - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: + def inner_update(self, batch: TorchMiniBatch) -> dict[str, float]: raise NotImplementedError def get_action_type(self) -> ActionSpace: @@ -98,7 +99,9 @@ class DiscreteRandomPolicyConfig(LearnableConfig): ``fit`` and ``fit_online`` methods will raise exceptions. """ - def create(self, device: DeviceArg = False, enable_ddp: bool = False) -> "DiscreteRandomPolicy": # type: ignore + def create( # type: ignore + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "DiscreteRandomPolicy": return DiscreteRandomPolicy(self) @staticmethod @@ -128,7 +131,7 @@ def sample_action(self, x: Observation) -> NDArray: def predict_value(self, x: Observation, action: NDArray) -> NDArray: raise NotImplementedError - def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: + def inner_update(self, batch: TorchMiniBatch) -> dict[str, float]: raise NotImplementedError def get_action_type(self) -> ActionSpace: diff --git a/d3rlpy/algos/qlearning/torch/bc_impl.py b/d3rlpy/algos/qlearning/torch/bc_impl.py index 85b1d5f6..319e2ba8 100644 --- a/d3rlpy/algos/qlearning/torch/bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/bc_impl.py @@ -1,6 +1,6 @@ import dataclasses from abc import ABCMeta, abstractmethod -from typing import Callable, Dict, Union +from typing import Callable, Union import torch from torch.optim import Optimizer @@ -60,7 +60,7 @@ def compute_imitator_grad(self, batch: TorchMiniBatch) -> ImitationLoss: loss.loss.backward() return loss - def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_imitator(self, batch: TorchMiniBatch) -> dict[str, float]: loss = self._compute_imitator_grad(batch) self._modules.optim.step() return asdict_as_float(loss) @@ -81,7 +81,7 @@ def inner_predict_value( def inner_update( self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + ) -> dict[str, float]: return self.update_imitator(batch) diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py index 896b1ab9..6c77a4bb 100644 --- a/d3rlpy/algos/qlearning/torch/bcq_impl.py +++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py @@ -1,6 +1,6 @@ import dataclasses import math -from typing import Callable, Dict, cast +from typing import Callable, cast import torch import torch.nn.functional as F @@ -50,7 +50,7 @@ class BCQModules(DDPGBaseModules): class BCQImpl(DDPGBaseImpl): _modules: BCQModules - _compute_imitator_grad: Callable[[TorchMiniBatch], Dict[str, torch.Tensor]] + _compute_imitator_grad: Callable[[TorchMiniBatch], dict[str, torch.Tensor]] _lam: float _n_action_samples: int _action_flexibility: float @@ -124,7 +124,7 @@ def compute_actor_grad(self, batch: TorchMiniBatch) -> DDPGBaseActorLoss: def compute_imitator_grad( self, batch: TorchMiniBatch - ) -> Dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor]: self._modules.vae_optim.zero_grad() loss = compute_vae_error( vae_encoder=self._modules.vae_encoder, @@ -136,7 +136,7 @@ def compute_imitator_grad( loss.backward() return {"loss": loss} - def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_imitator(self, batch: TorchMiniBatch) -> dict[str, float]: loss = self._compute_imitator_grad(batch) self._modules.vae_optim.step() return {"vae_loss": float(loss["loss"].cpu().detach().numpy())} @@ -214,7 +214,7 @@ def update_actor_target(self) -> None: def inner_update( self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + ) -> dict[str, float]: metrics = {} metrics.update(self.update_imitator(batch)) diff --git a/d3rlpy/algos/qlearning/torch/bear_impl.py b/d3rlpy/algos/qlearning/torch/bear_impl.py index 7adc6290..3957f00f 100644 --- a/d3rlpy/algos/qlearning/torch/bear_impl.py +++ b/d3rlpy/algos/qlearning/torch/bear_impl.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Callable, Dict, Optional +from typing import Callable, Optional import torch @@ -61,9 +61,9 @@ class BEARActorLoss(SACActorLoss): class BEARImpl(SACImpl): _modules: BEARModules _compute_warmup_actor_grad: Callable[ - [TorchMiniBatch], Dict[str, torch.Tensor] + [TorchMiniBatch], dict[str, torch.Tensor] ] - _compute_imitator_grad: Callable[[TorchMiniBatch], Dict[str, torch.Tensor]] + _compute_imitator_grad: Callable[[TorchMiniBatch], dict[str, torch.Tensor]] _alpha_threshold: float _lam: float _n_action_samples: int @@ -143,13 +143,13 @@ def compute_actor_loss( def compute_warmup_actor_grad( self, batch: TorchMiniBatch - ) -> Dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor]: self._modules.actor_optim.zero_grad() loss = self._compute_mmd_loss(batch.observations) loss.backward() return {"loss": loss} - def warmup_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: + def warmup_actor(self, batch: TorchMiniBatch) -> dict[str, float]: loss = self._compute_warmup_actor_grad(batch) self._modules.actor_optim.step() return {"actor_loss": float(loss["loss"].cpu().detach().numpy())} @@ -161,13 +161,13 @@ def _compute_mmd_loss(self, obs_t: TorchObservation) -> torch.Tensor: def compute_imitator_grad( self, batch: TorchMiniBatch - ) -> Dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor]: self._modules.vae_optim.zero_grad() loss = self.compute_imitator_loss(batch) loss.backward() return {"loss": loss} - def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_imitator(self, batch: TorchMiniBatch) -> dict[str, float]: loss = self._compute_imitator_grad(batch) self._modules.vae_optim.step() return {"imitator_loss": float(loss["loss"].cpu().detach().numpy())} @@ -301,7 +301,7 @@ def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor: def inner_update( self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + ) -> dict[str, float]: metrics = {} metrics.update(self.update_imitator(batch)) metrics.update(self.update_critic(batch)) diff --git a/d3rlpy/algos/qlearning/torch/cal_ql_impl.py b/d3rlpy/algos/qlearning/torch/cal_ql_impl.py index 8079ed25..4d494e5d 100644 --- a/d3rlpy/algos/qlearning/torch/cal_ql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cal_ql_impl.py @@ -1,4 +1,3 @@ -from typing import Tuple import torch @@ -14,7 +13,7 @@ def _compute_policy_is_values( policy_obs: TorchObservation, value_obs: TorchObservation, returns_to_go: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: values, log_probs = super()._compute_policy_is_values( policy_obs=policy_obs, value_obs=value_obs, diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index 9d986d77..f3eb7955 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -1,6 +1,6 @@ import dataclasses import math -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn.functional as F @@ -119,7 +119,7 @@ def _compute_policy_is_values( policy_obs: TorchObservation, value_obs: TorchObservation, returns_to_go: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: return sample_q_values_with_policy( policy=self._modules.policy, q_func_forwarder=self._q_func_forwarder, @@ -131,7 +131,7 @@ def _compute_policy_is_values( def _compute_random_is_values( self, obs: TorchObservation - ) -> Tuple[torch.Tensor, float]: + ) -> tuple[torch.Tensor, float]: # (batch, observation) -> (batch, n, observation) repeated_obs = expand_and_repeat_recursively( obs, self._n_action_samples diff --git a/d3rlpy/algos/qlearning/torch/crr_impl.py b/d3rlpy/algos/qlearning/torch/crr_impl.py index eda3a8b3..6f90e201 100644 --- a/d3rlpy/algos/qlearning/torch/crr_impl.py +++ b/d3rlpy/algos/qlearning/torch/crr_impl.py @@ -1,5 +1,4 @@ import dataclasses -from typing import Dict import torch import torch.nn.functional as F @@ -186,7 +185,7 @@ def update_actor_target(self) -> None: def inner_update( self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + ) -> dict[str, float]: metrics = {} metrics.update(self.update_critic(batch)) metrics.update(self.update_actor(batch)) diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index 620b0eb9..f92fc15f 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -1,6 +1,6 @@ import dataclasses from abc import ABCMeta, abstractmethod -from typing import Callable, Dict +from typing import Callable import torch from torch import nn @@ -105,7 +105,7 @@ def compute_critic_grad(self, batch: TorchMiniBatch) -> DDPGBaseCriticLoss: loss.critic_loss.backward() return loss - def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_critic(self, batch: TorchMiniBatch) -> dict[str, float]: loss = self._compute_critic_grad(batch) self._modules.critic_optim.step() return asdict_as_float(loss) @@ -130,7 +130,7 @@ def compute_actor_grad(self, batch: TorchMiniBatch) -> DDPGBaseActorLoss: loss.actor_loss.backward() return loss - def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_actor(self, batch: TorchMiniBatch) -> dict[str, float]: # Q function should be inference mode for stability self._modules.q_funcs.eval() loss = self._compute_actor_grad(batch) @@ -139,7 +139,7 @@ def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: def inner_update( self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + ) -> dict[str, float]: metrics = {} metrics.update(self.update_critic(batch)) metrics.update(self.update_actor(batch)) @@ -241,7 +241,7 @@ def update_actor_target(self) -> None: def inner_update( self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + ) -> dict[str, float]: metrics = super().inner_update(batch, grad_step) self.update_actor_target() return metrics diff --git a/d3rlpy/algos/qlearning/torch/dqn_impl.py b/d3rlpy/algos/qlearning/torch/dqn_impl.py index 0c6f5ba2..859b8b19 100644 --- a/d3rlpy/algos/qlearning/torch/dqn_impl.py +++ b/d3rlpy/algos/qlearning/torch/dqn_impl.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Callable, Dict +from typing import Callable import torch from torch import nn @@ -79,7 +79,7 @@ def compute_grad(self, batch: TorchMiniBatch) -> DQNLoss: def inner_update( self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + ) -> dict[str, float]: loss = self._compute_grad(batch) self._modules.optim.step() if grad_step % self._target_update_interval == 0: diff --git a/d3rlpy/algos/qlearning/torch/plas_impl.py b/d3rlpy/algos/qlearning/torch/plas_impl.py index 0fbe2bfc..66818e4a 100644 --- a/d3rlpy/algos/qlearning/torch/plas_impl.py +++ b/d3rlpy/algos/qlearning/torch/plas_impl.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Callable, Dict +from typing import Callable import torch @@ -36,7 +36,7 @@ class PLASModules(DDPGBaseModules): class PLASImpl(DDPGBaseImpl): _modules: PLASModules - _compute_imitator_grad: Callable[[TorchMiniBatch], Dict[str, torch.Tensor]] + _compute_imitator_grad: Callable[[TorchMiniBatch], dict[str, torch.Tensor]] _lam: float _beta: float _warmup_steps: int @@ -78,7 +78,7 @@ def __init__( def compute_imitator_grad( self, batch: TorchMiniBatch - ) -> Dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor]: self._modules.vae_optim.zero_grad() loss = compute_vae_error( vae_encoder=self._modules.vae_encoder, @@ -90,7 +90,7 @@ def compute_imitator_grad( loss.backward() return {"loss": loss} - def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_imitator(self, batch: TorchMiniBatch) -> dict[str, float]: loss = self._compute_imitator_grad(batch) self._modules.vae_optim.step() return {"vae_loss": float(loss["loss"].cpu().detach().numpy())} @@ -133,7 +133,7 @@ def update_actor_target(self) -> None: def inner_update( self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + ) -> dict[str, float]: metrics = {} if grad_step < self._warmup_steps: diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py index 58706147..e2c9a7ed 100644 --- a/d3rlpy/algos/qlearning/torch/sac_impl.py +++ b/d3rlpy/algos/qlearning/torch/sac_impl.py @@ -1,6 +1,6 @@ import dataclasses import math -from typing import Callable, Dict, Optional +from typing import Callable, Optional import torch import torch.nn.functional as F @@ -148,8 +148,8 @@ class DiscreteSACImpl(DiscreteQFunctionMixin, QLearningAlgoImplBase): _q_func_forwarder: DiscreteEnsembleQFunctionForwarder _targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder _target_update_interval: int - _compute_critic_grad: Callable[[TorchMiniBatch], Dict[str, torch.Tensor]] - _compute_actor_grad: Callable[[TorchMiniBatch], Dict[str, torch.Tensor]] + _compute_critic_grad: Callable[[TorchMiniBatch], dict[str, torch.Tensor]] + _compute_actor_grad: Callable[[TorchMiniBatch], dict[str, torch.Tensor]] def __init__( self, @@ -187,14 +187,14 @@ def __init__( def compute_critic_grad( self, batch: TorchMiniBatch - ) -> Dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor]: self._modules.critic_optim.zero_grad() q_tpn = self.compute_target(batch) loss = self.compute_critic_loss(batch, q_tpn) loss.backward() return {"loss": loss} - def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_critic(self, batch: TorchMiniBatch) -> dict[str, float]: loss = self._compute_critic_grad(batch) self._modules.critic_optim.step() return {"critic_loss": float(loss["loss"].cpu().detach().numpy())} @@ -235,13 +235,13 @@ def compute_critic_loss( def compute_actor_grad( self, batch: TorchMiniBatch - ) -> Dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor]: self._modules.actor_optim.zero_grad() loss = self.compute_actor_loss(batch) loss["loss"].backward() return loss - def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_actor(self, batch: TorchMiniBatch) -> dict[str, float]: # Q function should be inference mode for stability self._modules.q_funcs.eval() loss = self._compute_actor_grad(batch) @@ -250,7 +250,7 @@ def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: def compute_actor_loss( self, batch: TorchMiniBatch - ) -> Dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor]: with torch.no_grad(): q_t = self._q_func_forwarder.compute_expected_q( batch.observations, reduction="min" @@ -271,7 +271,7 @@ def compute_actor_loss( loss["loss"] = (probs * (entropy - q_t)).sum(dim=1).mean() return loss - def update_temp(self, dist: Categorical) -> Dict[str, torch.Tensor]: + def update_temp(self, dist: Categorical) -> dict[str, torch.Tensor]: assert self._modules.temp_optim assert self._modules.log_temp is not None self._modules.temp_optim.zero_grad() @@ -295,7 +295,7 @@ def update_temp(self, dist: Categorical) -> Dict[str, torch.Tensor]: def inner_update( self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + ) -> dict[str, float]: metrics = {} metrics.update(self.update_critic(batch)) metrics.update(self.update_actor(batch)) diff --git a/d3rlpy/algos/qlearning/torch/td3_impl.py b/d3rlpy/algos/qlearning/torch/td3_impl.py index 2dc6e513..73a43a0c 100644 --- a/d3rlpy/algos/qlearning/torch/td3_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_impl.py @@ -1,4 +1,3 @@ -from typing import Dict import torch @@ -64,7 +63,7 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: def inner_update( self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + ) -> dict[str, float]: metrics = {} metrics.update(self.update_critic(batch)) diff --git a/d3rlpy/algos/qlearning/torch/utility.py b/d3rlpy/algos/qlearning/torch/utility.py index 5f71c81f..2a548d2b 100644 --- a/d3rlpy/algos/qlearning/torch/utility.py +++ b/d3rlpy/algos/qlearning/torch/utility.py @@ -1,4 +1,3 @@ -from typing import Tuple import torch from typing_extensions import Protocol @@ -59,7 +58,7 @@ def sample_q_values_with_policy( value_observations: TorchObservation, n_action_samples: int, detach_policy_output: bool, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: dist = build_squashed_gaussian_distribution(policy(policy_observations)) # (batch, n, action), (batch, n) policy_actions, n_log_probs = dist.sample_n_with_log_prob(n_action_samples) diff --git a/d3rlpy/algos/transformer/base.py b/d3rlpy/algos/transformer/base.py index 7834f502..cb6831cf 100644 --- a/d3rlpy/algos/transformer/base.py +++ b/d3rlpy/algos/transformer/base.py @@ -3,8 +3,6 @@ from collections import defaultdict, deque from typing import ( Callable, - Deque, - Dict, Generic, Optional, Sequence, @@ -60,13 +58,13 @@ def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor: @train_api def update( self, batch: TorchTrajectoryMiniBatch, grad_step: int - ) -> Dict[str, float]: + ) -> dict[str, float]: return self.inner_update(batch, grad_step) @abstractmethod def inner_update( self, batch: TorchTrajectoryMiniBatch, grad_step: int - ) -> Dict[str, float]: + ) -> dict[str, float]: raise NotImplementedError @@ -115,11 +113,11 @@ class StatefulTransformerWrapper(Generic[TTransformerImpl, TTransformerConfig]): _target_return: float _action_sampler: TransformerActionSampler _return_rest: float - _observations: Deque[Observation] - _actions: Deque[Union[NDArray, int]] - _rewards: Deque[float] - _returns_to_go: Deque[float] - _timesteps: Deque[int] + _observations: deque[Observation] + _actions: deque[Union[NDArray, int]] + _rewards: deque[float] + _returns_to_go: deque[float] + _timesteps: deque[int] _timestep: int def __init__( @@ -510,7 +508,7 @@ def fit( logger.close() - def update(self, batch: TrajectoryMiniBatch) -> Dict[str, float]: + def update(self, batch: TrajectoryMiniBatch) -> dict[str, float]: """Update parameters with mini-batch of data. Args: diff --git a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py index a77dfda1..27fe3730 100644 --- a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py +++ b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py @@ -1,6 +1,6 @@ import dataclasses import math -from typing import Callable, Dict +from typing import Callable import torch import torch.nn.functional as F @@ -35,7 +35,7 @@ class DecisionTransformerModules(Modules): class DecisionTransformerImpl(TransformerAlgoImplBase): _modules: DecisionTransformerModules - _compute_grad: Callable[[TorchTrajectoryMiniBatch], Dict[str, torch.Tensor]] + _compute_grad: Callable[[TorchTrajectoryMiniBatch], dict[str, torch.Tensor]] def __init__( self, @@ -62,7 +62,7 @@ def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor: def compute_grad( self, batch: TorchTrajectoryMiniBatch - ) -> Dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor]: self._modules.optim.zero_grad() loss = self.compute_loss(batch) loss.backward() @@ -70,7 +70,7 @@ def compute_grad( def inner_update( self, batch: TorchTrajectoryMiniBatch, grad_step: int - ) -> Dict[str, float]: + ) -> dict[str, float]: loss = self._compute_grad(batch) self._modules.optim.step() return {"loss": float(loss["loss"].cpu().detach().numpy())} @@ -99,7 +99,7 @@ class DiscreteDecisionTransformerImpl(TransformerAlgoImplBase): _final_tokens: int _initial_learning_rate: float _tokens: int - _compute_grad: Callable[[TorchTrajectoryMiniBatch], Dict[str, torch.Tensor]] + _compute_grad: Callable[[TorchTrajectoryMiniBatch], dict[str, torch.Tensor]] def __init__( self, @@ -139,7 +139,7 @@ def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor: def compute_grad( self, batch: TorchTrajectoryMiniBatch - ) -> Dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor]: self._modules.optim.zero_grad() loss = self.compute_loss(batch) loss.backward() @@ -147,7 +147,7 @@ def compute_grad( def inner_update( self, batch: TorchTrajectoryMiniBatch, grad_step: int - ) -> Dict[str, float]: + ) -> dict[str, float]: loss = self._compute_grad(batch) self._modules.optim.step() diff --git a/d3rlpy/base.py b/d3rlpy/base.py index cb9541ee..956eb9dc 100644 --- a/d3rlpy/base.py +++ b/d3rlpy/base.py @@ -2,7 +2,7 @@ import io import pickle from abc import ABCMeta, abstractmethod -from typing import BinaryIO, Generic, Optional, Type, TypeVar, Union +from typing import BinaryIO, Generic, Optional, TypeVar, Union from gym.spaces import Box from gymnasium.spaces import Box as GymnasiumBox @@ -274,7 +274,7 @@ def load_model(self, fname: str) -> None: @classmethod def from_json( - cls: Type[Self], fname: str, device: DeviceArg = False + cls: type[Self], fname: str, device: DeviceArg = False ) -> Self: r"""Construct algorithm from params.json file. diff --git a/d3rlpy/cli.py b/d3rlpy/cli.py index 1a7fe308..ac79ea8b 100644 --- a/d3rlpy/cli.py +++ b/d3rlpy/cli.py @@ -5,7 +5,7 @@ import json import os import subprocess -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Optional, Sequence import click import gym @@ -50,7 +50,7 @@ def get_plt() -> "matplotlib.pyplot": def _compute_moving_average(values: np.ndarray, window: int) -> np.ndarray: assert values.ndim == 1 - results: List[float] = [] + results: list[float] = [] # average over past data for i in range(values.shape[0]): start = max(0, i - window) @@ -83,13 +83,13 @@ def stats(path: str) -> None: @click.option("--ylabel", default="value", help="Label on y-axis.") @click.option("--save", help="Flag to save the plot as an image.") def plot( - path: List[str], + path: list[str], window: int, show_steps: bool, show_max: bool, label: Optional[Sequence[str]], - xlim: Optional[Tuple[float, float]], - ylim: Optional[Tuple[float, float]], + xlim: Optional[tuple[float, float]], + ylim: Optional[tuple[float, float]], title: Optional[str], ylabel: str, save: str, @@ -230,7 +230,7 @@ def export(model_path: str, output_path: str) -> None: def _exec_to_create_env(code: str) -> gym.Env[Any, Any]: print(f"Executing '{code}'") - variables: Dict[str, Any] = {} + variables: dict[str, Any] = {} exec(code, globals(), variables) if "env" not in variables: raise RuntimeError("env must be defined in env_header.") @@ -349,13 +349,13 @@ def play( def _install_module( - name: List[str], upgrade: bool = False, check: bool = True + name: list[str], upgrade: bool = False, check: bool = True ) -> None: name = ["-U", *name] if upgrade else name subprocess.run(["pip3", "install", *name], check=check) -def _uninstall_module(name: List[str], check: bool = True) -> None: +def _uninstall_module(name: list[str], check: bool = True) -> None: subprocess.run(["pip3", "uninstall", "-y", *name], check=check) @@ -371,7 +371,6 @@ def _uninstall_module(name: List[str], check: bool = True) -> None: @cli.command(short_help="Install additional packages.") @click.argument("name") def install(name: str) -> None: - def print_available_options() -> None: print("List of available options.") for name, description in INSTALL_OPTIONS.items(): diff --git a/d3rlpy/dataclass_utils.py b/d3rlpy/dataclass_utils.py index 79499a7a..70d7ea17 100644 --- a/d3rlpy/dataclass_utils.py +++ b/d3rlpy/dataclass_utils.py @@ -1,21 +1,21 @@ import dataclasses -from typing import Any, Dict +from typing import Any import torch __all__ = ["asdict_without_copy", "asdict_as_float"] -def asdict_without_copy(obj: Any) -> Dict[str, Any]: +def asdict_without_copy(obj: Any) -> dict[str, Any]: assert dataclasses.is_dataclass(obj) fields = dataclasses.fields(obj) return {field.name: getattr(obj, field.name) for field in fields} -def asdict_as_float(obj: Any) -> Dict[str, float]: +def asdict_as_float(obj: Any) -> dict[str, float]: assert dataclasses.is_dataclass(obj) fields = dataclasses.fields(obj) - ret: Dict[str, float] = {} + ret: dict[str, float] = {} for field in fields: value = getattr(obj, field.name) if isinstance(value, torch.Tensor): diff --git a/d3rlpy/dataset/buffers.py b/d3rlpy/dataset/buffers.py index d2cde6e1..36bacd77 100644 --- a/d3rlpy/dataset/buffers.py +++ b/d3rlpy/dataset/buffers.py @@ -1,5 +1,5 @@ from collections import deque -from typing import Deque, List, Sequence, Tuple +from typing import Sequence from typing_extensions import Protocol @@ -38,15 +38,15 @@ def transition_count(self) -> int: """ raise NotImplementedError - def __getitem__(self, index: int) -> Tuple[EpisodeBase, int]: + def __getitem__(self, index: int) -> tuple[EpisodeBase, int]: raise NotImplementedError class InfiniteBuffer(BufferProtocol): r"""Buffer with unlimited capacity.""" - _transitions: List[Tuple[EpisodeBase, int]] - _episodes: List[EpisodeBase] + _transitions: list[tuple[EpisodeBase, int]] + _episodes: list[EpisodeBase] def __init__(self) -> None: self._transitions = [] @@ -69,7 +69,7 @@ def transition_count(self) -> int: def __len__(self) -> int: return len(self._transitions) - def __getitem__(self, index: int) -> Tuple[EpisodeBase, int]: + def __getitem__(self, index: int) -> tuple[EpisodeBase, int]: return self._transitions[index] @@ -80,8 +80,8 @@ class FIFOBuffer(BufferProtocol): limit (int): buffer capacity. """ - _transitions: Deque[Tuple[EpisodeBase, int]] - _episodes: List[EpisodeBase] + _transitions: deque[tuple[EpisodeBase, int]] + _episodes: list[EpisodeBase] _limit: int def __init__(self, limit: int): @@ -109,5 +109,5 @@ def transition_count(self) -> int: def __len__(self) -> int: return len(self._transitions) - def __getitem__(self, index: int) -> Tuple[EpisodeBase, int]: + def __getitem__(self, index: int) -> tuple[EpisodeBase, int]: return self._transitions[index] diff --git a/d3rlpy/dataset/components.py b/d3rlpy/dataset/components.py index 7603daa0..a41a2198 100644 --- a/d3rlpy/dataset/components.py +++ b/d3rlpy/dataset/components.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Dict, Sequence +from typing import Any, Sequence import numpy as np from typing_extensions import Protocol @@ -282,7 +282,7 @@ def compute_return(self) -> float: """ raise NotImplementedError - def serialize(self) -> Dict[str, Any]: + def serialize(self) -> dict[str, Any]: r"""Returns serized episode data. Returns: @@ -291,7 +291,7 @@ def serialize(self) -> Dict[str, Any]: raise NotImplementedError @classmethod - def deserialize(cls, serializedData: Dict[str, Any]) -> "EpisodeBase": + def deserialize(cls, serializedData: dict[str, Any]) -> "EpisodeBase": r"""Constructs episode from serialized data. This is an inverse operation of ``serialize`` method. @@ -362,7 +362,7 @@ def size(self) -> int: def compute_return(self) -> float: return float(np.sum(self.rewards)) - def serialize(self) -> Dict[str, Any]: + def serialize(self) -> dict[str, Any]: return { "observations": self.observations, "actions": self.actions, @@ -371,7 +371,7 @@ def serialize(self) -> Dict[str, Any]: } @classmethod - def deserialize(cls, serializedData: Dict[str, Any]) -> "Episode": + def deserialize(cls, serializedData: dict[str, Any]) -> "Episode": return cls( observations=serializedData["observations"], actions=serializedData["actions"], diff --git a/d3rlpy/dataset/io.py b/d3rlpy/dataset/io.py index b0d8e230..345989be 100644 --- a/d3rlpy/dataset/io.py +++ b/d3rlpy/dataset/io.py @@ -1,4 +1,4 @@ -from typing import BinaryIO, Sequence, Type, TypeVar, cast +from typing import BinaryIO, Sequence, TypeVar, cast import h5py import numpy as np @@ -39,7 +39,7 @@ def dump(episodes: Sequence[EpisodeBase], f: BinaryIO) -> None: _TEpisode = TypeVar("_TEpisode", bound=EpisodeBase) -def load(episode_cls: Type[_TEpisode], f: BinaryIO) -> Sequence[_TEpisode]: +def load(episode_cls: type[_TEpisode], f: BinaryIO) -> Sequence[_TEpisode]: r"""Constructs episodes from file-like object. Args: diff --git a/d3rlpy/dataset/replay_buffer.py b/d3rlpy/dataset/replay_buffer.py index 0b70f5b6..2ebf945b 100644 --- a/d3rlpy/dataset/replay_buffer.py +++ b/d3rlpy/dataset/replay_buffer.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import BinaryIO, List, Optional, Sequence, Type, Union +from typing import BinaryIO, Optional, Sequence, Union import numpy as np @@ -172,7 +172,7 @@ def load( cls, f: BinaryIO, buffer: BufferProtocol, - episode_cls: Type[EpisodeBase] = Episode, + episode_cls: type[EpisodeBase] = Episode, transition_picker: Optional[TransitionPickerProtocol] = None, trajectory_slicer: Optional[TrajectorySlicerProtocol] = None, writer_preprocessor: Optional[WriterPreprocessProtocol] = None, @@ -336,7 +336,7 @@ class ReplayBuffer(ReplayBufferBase): _transition_picker: TransitionPickerProtocol _trajectory_slicer: TrajectorySlicerProtocol _writer: ExperienceWriter - _episodes: List[EpisodeBase] + _episodes: list[EpisodeBase] _dataset_info: DatasetInfo def __init__( @@ -515,7 +515,7 @@ def load( cls, f: BinaryIO, buffer: BufferProtocol, - episode_cls: Type[EpisodeBase] = Episode, + episode_cls: type[EpisodeBase] = Episode, transition_picker: Optional[TransitionPickerProtocol] = None, trajectory_slicer: Optional[TrajectorySlicerProtocol] = None, writer_preprocessor: Optional[WriterPreprocessProtocol] = None, @@ -693,7 +693,7 @@ def load( cls, f: BinaryIO, buffer: BufferProtocol, - episode_cls: Type[EpisodeBase] = Episode, + episode_cls: type[EpisodeBase] = Episode, transition_picker: Optional[TransitionPickerProtocol] = None, trajectory_slicer: Optional[TrajectorySlicerProtocol] = None, writer_preprocessor: Optional[WriterPreprocessProtocol] = None, diff --git a/d3rlpy/dataset/utils.py b/d3rlpy/dataset/utils.py index 02d51b18..cf8888e7 100644 --- a/d3rlpy/dataset/utils.py +++ b/d3rlpy/dataset/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Sequence, Type, TypeVar, Union, overload +from typing import Any, Sequence, TypeVar, Union, overload import numpy as np import numpy.typing as npt @@ -315,18 +315,18 @@ def check_non_1d_array(array: Union[NDArray, Sequence[NDArray]]) -> bool: @overload def cast_recursively( - array: NDArray, dtype: Type[_TDType] + array: NDArray, dtype: type[_TDType] ) -> npt.NDArray[_TDType]: ... @overload def cast_recursively( - array: Sequence[NDArray], dtype: Type[_TDType] + array: Sequence[NDArray], dtype: type[_TDType] ) -> Sequence[npt.NDArray[_TDType]]: ... def cast_recursively( - array: Union[NDArray, Sequence[NDArray]], dtype: Type[_TDType] + array: Union[NDArray, Sequence[NDArray]], dtype: type[_TDType] ) -> Union[npt.NDArray[_TDType], Sequence[npt.NDArray[_TDType]]]: if isinstance(array, (list, tuple)): return [array[i].astype(dtype) for i in range(len(array))] diff --git a/d3rlpy/dataset/writers.py b/d3rlpy/dataset/writers.py index 7c345073..5b4a8621 100644 --- a/d3rlpy/dataset/writers.py +++ b/d3rlpy/dataset/writers.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Sequence, Union +from typing import Any, Sequence, Union import numpy as np from typing_extensions import Protocol @@ -219,7 +219,7 @@ def reward_signature(self) -> Signature: def compute_return(self) -> float: return float(np.sum(self.rewards[: self._cursor])) - def serialize(self) -> Dict[str, Any]: + def serialize(self) -> dict[str, Any]: return { "observations": self.observations, "actions": self.actions, @@ -228,7 +228,7 @@ def serialize(self) -> Dict[str, Any]: } @classmethod - def deserialize(cls, serializedData: Dict[str, Any]) -> "EpisodeBase": + def deserialize(cls, serializedData: dict[str, Any]) -> "EpisodeBase": raise NotImplementedError("_ActiveEpisode cannot be deserialized.") def __len__(self) -> int: diff --git a/d3rlpy/datasets.py b/d3rlpy/datasets.py index 937fe141..2a1207e4 100644 --- a/d3rlpy/datasets.py +++ b/d3rlpy/datasets.py @@ -4,7 +4,7 @@ import os import random import re -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from urllib import request import gym @@ -52,9 +52,13 @@ DATA_DIRECTORY = "d3rlpy_data" DROPBOX_URL = "https://www.dropbox.com/s" CARTPOLE_URL = f"{DROPBOX_URL}/uep0lzlhxpi79pd/cartpole_v1.1.0.h5?dl=1" -CARTPOLE_RANDOM_URL = f"{DROPBOX_URL}/4lgai7tgj84cbov/cartpole_random_v1.1.0.h5?dl=1" # pylint: disable=line-too-long +CARTPOLE_RANDOM_URL = ( + f"{DROPBOX_URL}/4lgai7tgj84cbov/cartpole_random_v1.1.0.h5?dl=1" # pylint: disable=line-too-long +) PENDULUM_URL = f"{DROPBOX_URL}/ukkucouzys0jkfs/pendulum_v1.1.0.h5?dl=1" -PENDULUM_RANDOM_URL = f"{DROPBOX_URL}/hhbq9i6ako24kzz/pendulum_random_v1.1.0.h5?dl=1" # pylint: disable=line-too-long +PENDULUM_RANDOM_URL = ( + f"{DROPBOX_URL}/hhbq9i6ako24kzz/pendulum_random_v1.1.0.h5?dl=1" # pylint: disable=line-too-long +) def get_cartpole( @@ -62,7 +66,7 @@ def get_cartpole( transition_picker: Optional[TransitionPickerProtocol] = None, trajectory_slicer: Optional[TrajectorySlicerProtocol] = None, render_mode: Optional[str] = None, -) -> Tuple[ReplayBuffer, gym.Env[NDArray, int]]: +) -> tuple[ReplayBuffer, gym.Env[NDArray, int]]: """Returns cartpole dataset and environment. The dataset is automatically downloaded to ``d3rlpy_data/cartpole.h5`` if @@ -116,7 +120,7 @@ def get_pendulum( transition_picker: Optional[TransitionPickerProtocol] = None, trajectory_slicer: Optional[TrajectorySlicerProtocol] = None, render_mode: Optional[str] = None, -) -> Tuple[ReplayBuffer, gym.Env[NDArray, NDArray]]: +) -> tuple[ReplayBuffer, gym.Env[NDArray, NDArray]]: """Returns pendulum dataset and environment. The dataset is automatically downloaded to ``d3rlpy_data/pendulum.h5`` if @@ -193,7 +197,7 @@ def get_atari( sticky_action: bool = True, pre_stack: bool = False, render_mode: Optional[str] = None, -) -> Tuple[ReplayBuffer, gym.Env[NDArray, int]]: +) -> tuple[ReplayBuffer, gym.Env[NDArray, int]]: """Returns atari dataset and envrironment. The dataset is provided through d4rl-atari. See more details including @@ -221,7 +225,7 @@ def get_atari( tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment. """ try: - import d4rl_atari # type: ignore + import d4rl_atari # type: ignore # noqa env = gym.make( env_name, @@ -273,7 +277,7 @@ def get_atari_transitions( sticky_action: bool = True, pre_stack: bool = False, render_mode: Optional[str] = None, -) -> Tuple[ReplayBuffer, gym.Env[NDArray, int]]: +) -> tuple[ReplayBuffer, gym.Env[NDArray, int]]: """Returns atari dataset as a list of Transition objects and envrironment. The dataset is provided through d4rl-atari. @@ -307,7 +311,7 @@ def get_atari_transitions( environment. """ try: - import d4rl_atari + import d4rl_atari # noqa # each epoch consists of 1M steps num_transitions_per_epoch = int(1000000 * fraction) @@ -390,7 +394,7 @@ def get_d4rl( trajectory_slicer: Optional[TrajectorySlicerProtocol] = None, render_mode: Optional[str] = None, max_episode_steps: int = 1000, -) -> Tuple[ReplayBuffer, gym.Env[NDArray, NDArray]]: +) -> tuple[ReplayBuffer, gym.Env[NDArray, NDArray]]: """Returns d4rl dataset and envrironment. The dataset is provided through d4rl. @@ -417,14 +421,14 @@ def get_d4rl( tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment. """ try: - import d4rl + import d4rl # noqa from d4rl.locomotion.wrappers import NormalizedBoxEnv from d4rl.utils.wrappers import ( NormalizedBoxEnv as NormalizedBoxEnvFromUtils, ) env = gym.make(env_name) - raw_dataset: Dict[str, NDArray] = env.get_dataset() # type: ignore + raw_dataset: dict[str, NDArray] = env.get_dataset() # type: ignore observations = raw_dataset["observations"] actions = raw_dataset["actions"] @@ -470,7 +474,7 @@ def get_minari( trajectory_slicer: Optional[TrajectorySlicerProtocol] = None, render_mode: Optional[str] = None, tuple_observation: bool = False, -) -> Tuple[ReplayBuffer, gymnasium.Env[Any, Any]]: +) -> tuple[ReplayBuffer, gymnasium.Env[Any, Any]]: """Returns minari dataset and envrironment. The dataset is provided through minari. @@ -654,7 +658,7 @@ def get_dataset( transition_picker: Optional[TransitionPickerProtocol] = None, trajectory_slicer: Optional[TrajectorySlicerProtocol] = None, render_mode: Optional[str] = None, -) -> Tuple[ReplayBuffer, gym.Env[Any, Any]]: +) -> tuple[ReplayBuffer, gym.Env[Any, Any]]: """Returns dataset and envrironment by guessing from name. This function returns dataset by matching name with the following datasets. diff --git a/d3rlpy/envs/wrappers.py b/d3rlpy/envs/wrappers.py index 407811eb..5abe7928 100644 --- a/d3rlpy/envs/wrappers.py +++ b/d3rlpy/envs/wrappers.py @@ -1,12 +1,9 @@ from collections import deque from typing import ( Any, - Deque, - Dict, Optional, Sequence, SupportsFloat, - Tuple, TypeVar, Union, ) @@ -79,7 +76,7 @@ def __init__(self, env: gym.Env[_ObsType, _ActType]): def step( self, action: _ActType - ) -> Tuple[_ObsType, float, bool, bool, Dict[str, Any]]: + ) -> tuple[_ObsType, float, bool, bool, dict[str, Any]]: observation, reward, terminal, truncated, info = self.env.step(action) # make channel first observation if observation.ndim == 3: @@ -89,7 +86,7 @@ def step( assert observation_T.shape == self.observation_space.shape return observation_T, reward, terminal, truncated, info # type: ignore - def reset(self, **kwargs: Any) -> Tuple[_ObsType, Dict[str, Any]]: + def reset(self, **kwargs: Any) -> tuple[_ObsType, dict[str, Any]]: observation, info = self.env.reset(**kwargs) # make channel first observation if observation.ndim == 3: @@ -112,7 +109,7 @@ class FrameStack(gym.Wrapper[NDArray, _ActType]): """ _num_stack: int - _frames: Deque[NDArray] + _frames: deque[NDArray] def __init__(self, env: gym.Env[NDArray, _ActType], num_stack: int): super().__init__(env) @@ -140,12 +137,12 @@ def observation(self, observation: Any) -> NDArray: def step( self, action: _ActType - ) -> Tuple[NDArray, float, bool, bool, Dict[str, Any]]: + ) -> tuple[NDArray, float, bool, bool, dict[str, Any]]: observation, reward, terminated, truncated, info = self.env.step(action) self._frames.append(observation) return self.observation(None), reward, terminated, truncated, info - def reset(self, **kwargs: Any) -> Tuple[NDArray, Dict[str, Any]]: + def reset(self, **kwargs: Any) -> tuple[NDArray, dict[str, Any]]: obs, info = self.env.reset(**kwargs) for _ in range(self._num_stack - 1): self._frames.append(np.zeros_like(obs)) @@ -256,7 +253,7 @@ def __init__( def step( self, action: int - ) -> Tuple[NDArray, float, bool, bool, Dict[str, Any]]: + ) -> tuple[NDArray, float, bool, bool, dict[str, Any]]: R = 0.0 for t in range(self.frame_skip): @@ -284,7 +281,7 @@ def step( return self._get_obs(), R, done, truncated, info - def reset(self, **kwargs: Any) -> Tuple[NDArray, Dict[str, Any]]: + def reset(self, **kwargs: Any) -> tuple[NDArray, dict[str, Any]]: # this condition is not included in the original code if self.game_over: _, info = self.env.reset(**kwargs) @@ -364,16 +361,16 @@ def _get_keys_from_observation_space( return sorted(list(observation_space.keys())) -def _flat_dict_observation(observation: Dict[str, NDArray]) -> NDArray: +def _flat_dict_observation(observation: dict[str, NDArray]) -> NDArray: sorted_keys = sorted(list(observation.keys())) return np.concatenate([observation[key] for key in sorted_keys]) class GoalConcatWrapper( gymnasium.Wrapper[ - Union[NDArray, Tuple[NDArray, NDArray]], + Union[NDArray, tuple[NDArray, NDArray]], _ActType, - Dict[str, NDArray], + dict[str, NDArray], _ActType, ] ): @@ -397,7 +394,7 @@ class GoalConcatWrapper( def __init__( self, - env: gymnasium.Env[Dict[str, NDArray], _ActType], + env: gymnasium.Env[dict[str, NDArray], _ActType], observation_key: str = "observation", goal_key: str = "desired_goal", tuple_observation: bool = False, @@ -452,18 +449,20 @@ def __init__( dtype=observation_space.dtype, # type: ignore ) - def step(self, action: _ActType) -> Tuple[ - Union[NDArray, Tuple[NDArray, NDArray]], + def step( + self, action: _ActType + ) -> tuple[ + Union[NDArray, tuple[NDArray, NDArray]], SupportsFloat, bool, bool, - Dict[str, Any], + dict[str, Any], ]: obs, rew, terminal, truncate, info = self.env.step(action) goal_obs = obs[self._goal_key] if isinstance(goal_obs, dict): goal_obs = _flat_dict_observation(goal_obs) - concat_obs: Union[NDArray, Tuple[NDArray, NDArray]] + concat_obs: Union[NDArray, tuple[NDArray, NDArray]] if self._tuple_observation: concat_obs = (obs[self._observation_key], goal_obs) else: @@ -474,13 +473,13 @@ def reset( self, *, seed: Optional[int] = None, - options: Optional[Dict[str, Any]] = None, - ) -> Tuple[Union[NDArray, Tuple[NDArray, NDArray]], Dict[str, Any]]: + options: Optional[dict[str, Any]] = None, + ) -> tuple[Union[NDArray, tuple[NDArray, NDArray]], dict[str, Any]]: obs, info = self.env.reset(seed=seed, options=options) goal_obs = obs[self._goal_key] if isinstance(goal_obs, dict): goal_obs = _flat_dict_observation(goal_obs) - concat_obs: Union[NDArray, Tuple[NDArray, NDArray]] + concat_obs: Union[NDArray, tuple[NDArray, NDArray]] if self._tuple_observation: concat_obs = (obs[self._observation_key], goal_obs) else: diff --git a/d3rlpy/itertools.py b/d3rlpy/itertools.py index 79bb892c..9cabdf3b 100644 --- a/d3rlpy/itertools.py +++ b/d3rlpy/itertools.py @@ -1,17 +1,17 @@ -from typing import Iterable, Iterator, Tuple, TypeVar +from typing import Iterable, Iterator, TypeVar __all__ = ["last_flag", "first_flag"] T = TypeVar("T") -def last_flag(iterator: Iterable[T]) -> Iterator[Tuple[bool, T]]: +def last_flag(iterator: Iterable[T]) -> Iterator[tuple[bool, T]]: items = list(iterator) for i, item in enumerate(items): yield i == len(items) - 1, item -def first_flag(iterator: Iterable[T]) -> Iterator[Tuple[bool, T]]: +def first_flag(iterator: Iterable[T]) -> Iterator[tuple[bool, T]]: items = list(iterator) for i, item in enumerate(items): yield i == 0, item diff --git a/d3rlpy/logging/file_adapter.py b/d3rlpy/logging/file_adapter.py index 637402c1..9f8cb6ff 100644 --- a/d3rlpy/logging/file_adapter.py +++ b/d3rlpy/logging/file_adapter.py @@ -1,7 +1,7 @@ import json import os from enum import Enum, IntEnum -from typing import Any, Dict +from typing import Any import numpy as np @@ -52,7 +52,7 @@ def __init__(self, algo: AlgProtocol, logdir: str): os.makedirs(self._logdir) LOG.info(f"Directory is created at {self._logdir}") - def write_params(self, params: Dict[str, Any]) -> None: + def write_params(self, params: dict[str, Any]) -> None: # save dictionary as json file params_path = os.path.join(self._logdir, "params.json") with open(params_path, "w") as f: diff --git a/d3rlpy/logging/logger.py b/d3rlpy/logging/logger.py index c5c42d5d..6811a8b3 100644 --- a/d3rlpy/logging/logger.py +++ b/d3rlpy/logging/logger.py @@ -2,7 +2,7 @@ from collections import defaultdict from contextlib import contextmanager from datetime import datetime -from typing import Any, DefaultDict, Dict, Iterator, List, Optional, Tuple +from typing import Any, Iterator, Optional import structlog from torch import nn @@ -43,8 +43,8 @@ def save(self, fname: str) -> None: ... class ModuleProtocol(Protocol): - def get_torch_modules(self) -> Dict[str, nn.Module]: ... - def get_gradients(self) -> Iterator[Tuple[str, Float32NDArray]]: ... + def get_torch_modules(self) -> dict[str, nn.Module]: ... + def get_gradients(self) -> Iterator[tuple[str, Float32NDArray]]: ... class ImplProtocol(Protocol): @@ -62,7 +62,7 @@ def impl(self) -> Optional[ImplProtocol]: class LoggerAdapter(Protocol): r"""Interface of LoggerAdapter.""" - def write_params(self, params: Dict[str, Any]) -> None: + def write_params(self, params: dict[str, Any]) -> None: r"""Writes hyperparameters. Args: @@ -145,7 +145,7 @@ class D3RLPyLogger: _algo: AlgProtocol _adapter: LoggerAdapter _experiment_name: str - _metrics_buffer: DefaultDict[str, List[float]] + _metrics_buffer: defaultdict[str, list[float]] def __init__( self, @@ -166,14 +166,14 @@ def __init__( ) self._metrics_buffer = defaultdict(list) - def add_params(self, params: Dict[str, Any]) -> None: + def add_params(self, params: dict[str, Any]) -> None: self._adapter.write_params(params) LOG.info("Parameters", params=params) def add_metric(self, name: str, value: float) -> None: self._metrics_buffer[name].append(value) - def commit(self, epoch: int, step: int) -> Dict[str, float]: + def commit(self, epoch: int, step: int) -> dict[str, float]: self._adapter.before_write_metric(epoch, step) metrics = {} diff --git a/d3rlpy/logging/noop_adapter.py b/d3rlpy/logging/noop_adapter.py index 18bae16f..74089faf 100644 --- a/d3rlpy/logging/noop_adapter.py +++ b/d3rlpy/logging/noop_adapter.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any from .logger import ( AlgProtocol, @@ -17,7 +17,7 @@ class NoopAdapter(LoggerAdapter): are not allowed to write things to disks. """ - def write_params(self, params: Dict[str, Any]) -> None: + def write_params(self, params: dict[str, Any]) -> None: pass def before_write_metric(self, epoch: int, step: int) -> None: diff --git a/d3rlpy/logging/tensorboard_adapter.py b/d3rlpy/logging/tensorboard_adapter.py index 3112830f..7ad1a7b7 100644 --- a/d3rlpy/logging/tensorboard_adapter.py +++ b/d3rlpy/logging/tensorboard_adapter.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict +from typing import Any import numpy as np @@ -29,8 +29,8 @@ class TensorboardAdapter(LoggerAdapter): _algo: AlgProtocol _experiment_name: str - _params: Dict[str, Any] - _metrics: Dict[str, float] + _params: dict[str, Any] + _metrics: dict[str, float] def __init__(self, algo: AlgProtocol, root_dir: str, experiment_name: str): try: @@ -44,7 +44,7 @@ def __init__(self, algo: AlgProtocol, root_dir: str, experiment_name: str): self._writer = SummaryWriter(logdir=logdir) self._metrics = {} - def write_params(self, params: Dict[str, Any]) -> None: + def write_params(self, params: dict[str, Any]) -> None: # remove non-scaler values for HParams self._params = {k: v for k, v in params.items() if np.isscalar(v)} diff --git a/d3rlpy/logging/utils.py b/d3rlpy/logging/utils.py index cdb47e26..249fe6e9 100644 --- a/d3rlpy/logging/utils.py +++ b/d3rlpy/logging/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Sequence +from typing import Any, Sequence from .logger import ( AlgProtocol, @@ -23,7 +23,7 @@ class CombineAdapter(LoggerAdapter): def __init__(self, adapters: Sequence[LoggerAdapter]): self._adapters = adapters - def write_params(self, params: Dict[str, Any]) -> None: + def write_params(self, params: dict[str, Any]) -> None: for adapter in self._adapters: adapter.write_params(params) diff --git a/d3rlpy/logging/wandb_adapter.py b/d3rlpy/logging/wandb_adapter.py index cf4ba6a4..c8ba0fb5 100644 --- a/d3rlpy/logging/wandb_adapter.py +++ b/d3rlpy/logging/wandb_adapter.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Optional from .logger import ( AlgProtocol, @@ -42,7 +42,7 @@ def __init__( ) self._is_model_watched = False - def write_params(self, params: Dict[str, Any]) -> None: + def write_params(self, params: dict[str, Any]) -> None: """Writes hyperparameters to WandB config.""" self.run.config.update(params) diff --git a/d3rlpy/models/builders.py b/d3rlpy/models/builders.py index 3238b6ed..74564013 100644 --- a/d3rlpy/models/builders.py +++ b/d3rlpy/models/builders.py @@ -1,4 +1,4 @@ -from typing import Sequence, Tuple, cast +from typing import Sequence, cast import torch from torch import nn @@ -52,7 +52,7 @@ def create_discrete_q_function( device: str, enable_ddp: bool, n_ensembles: int = 1, -) -> Tuple[nn.ModuleList, DiscreteEnsembleQFunctionForwarder]: +) -> tuple[nn.ModuleList, DiscreteEnsembleQFunctionForwarder]: if q_func_factory.share_encoder: encoder = encoder_factory.create(observation_shape) hidden_size = compute_output_size([observation_shape], encoder) @@ -90,7 +90,7 @@ def create_continuous_q_function( device: str, enable_ddp: bool, n_ensembles: int = 1, -) -> Tuple[nn.ModuleList, ContinuousEnsembleQFunctionForwarder]: +) -> tuple[nn.ModuleList, ContinuousEnsembleQFunctionForwarder]: if q_func_factory.share_encoder: encoder = encoder_factory.create_with_action( observation_shape, action_size diff --git a/d3rlpy/models/encoders.py b/d3rlpy/models/encoders.py index 2b6fd906..b26e549e 100644 --- a/d3rlpy/models/encoders.py +++ b/d3rlpy/models/encoders.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List, Optional, Union +from typing import Optional, Union from ..dataset import cast_flat_shape from ..serializable_config import DynamicConfig, generate_config_registration @@ -75,7 +75,7 @@ class PixelEncoderFactory(EncoderFactory): last_activation (str): Activation function name for the last layer. """ - filters: List[List[int]] = field( + filters: list[list[int]] = field( default_factory=lambda: [[32, 8, 4], [64, 4, 2], [64, 3, 1]] ) feature_size: int = 512 @@ -149,7 +149,7 @@ class VectorEncoderFactory(EncoderFactory): last_activation (str): Activation function name for the last layer. """ - hidden_units: List[int] = field(default_factory=lambda: [256, 256]) + hidden_units: list[int] = field(default_factory=lambda: [256, 256]) activation: str = "relu" use_batch_norm: bool = False use_layer_norm: bool = False diff --git a/d3rlpy/models/q_functions.py b/d3rlpy/models/q_functions.py index 55681627..2f69f811 100644 --- a/d3rlpy/models/q_functions.py +++ b/d3rlpy/models/q_functions.py @@ -1,5 +1,4 @@ import dataclasses -from typing import Tuple from ..serializable_config import DynamicConfig, generate_config_registration from .torch import ( @@ -38,7 +37,7 @@ class QFunctionFactory(DynamicConfig): def create_discrete( self, encoder: Encoder, hidden_size: int, action_size: int - ) -> Tuple[DiscreteQFunction, DiscreteQFunctionForwarder]: + ) -> tuple[DiscreteQFunction, DiscreteQFunctionForwarder]: """Returns PyTorch's Q function module. Args: @@ -54,7 +53,7 @@ def create_discrete( def create_continuous( self, encoder: EncoderWithAction, hidden_size: int - ) -> Tuple[ContinuousQFunction, ContinuousQFunctionForwarder]: + ) -> tuple[ContinuousQFunction, ContinuousQFunctionForwarder]: """Returns PyTorch's Q function module. Args: @@ -98,7 +97,7 @@ def create_discrete( encoder: Encoder, hidden_size: int, action_size: int, - ) -> Tuple[DiscreteMeanQFunction, DiscreteMeanQFunctionForwarder]: + ) -> tuple[DiscreteMeanQFunction, DiscreteMeanQFunctionForwarder]: q_func = DiscreteMeanQFunction(encoder, hidden_size, action_size) forwarder = DiscreteMeanQFunctionForwarder(q_func, action_size) return q_func, forwarder @@ -107,7 +106,7 @@ def create_continuous( self, encoder: EncoderWithAction, hidden_size: int, - ) -> Tuple[ContinuousMeanQFunction, ContinuousMeanQFunctionForwarder]: + ) -> tuple[ContinuousMeanQFunction, ContinuousMeanQFunctionForwarder]: q_func = ContinuousMeanQFunction(encoder, hidden_size) forwarder = ContinuousMeanQFunctionForwarder(q_func) return q_func, forwarder @@ -134,7 +133,7 @@ class QRQFunctionFactory(QFunctionFactory): def create_discrete( self, encoder: Encoder, hidden_size: int, action_size: int - ) -> Tuple[DiscreteQRQFunction, DiscreteQRQFunctionForwarder]: + ) -> tuple[DiscreteQRQFunction, DiscreteQRQFunctionForwarder]: q_func = DiscreteQRQFunction( encoder=encoder, hidden_size=hidden_size, @@ -148,7 +147,7 @@ def create_continuous( self, encoder: EncoderWithAction, hidden_size: int, - ) -> Tuple[ContinuousQRQFunction, ContinuousQRQFunctionForwarder]: + ) -> tuple[ContinuousQRQFunction, ContinuousQRQFunctionForwarder]: q_func = ContinuousQRQFunction( encoder=encoder, hidden_size=hidden_size, @@ -186,7 +185,7 @@ def create_discrete( encoder: Encoder, hidden_size: int, action_size: int, - ) -> Tuple[DiscreteIQNQFunction, DiscreteIQNQFunctionForwarder]: + ) -> tuple[DiscreteIQNQFunction, DiscreteIQNQFunctionForwarder]: q_func = DiscreteIQNQFunction( encoder=encoder, hidden_size=hidden_size, @@ -200,7 +199,7 @@ def create_discrete( def create_continuous( self, encoder: EncoderWithAction, hidden_size: int - ) -> Tuple[ContinuousIQNQFunction, ContinuousIQNQFunctionForwarder]: + ) -> tuple[ContinuousIQNQFunction, ContinuousIQNQFunctionForwarder]: q_func = ContinuousIQNQFunction( encoder=encoder, hidden_size=hidden_size, diff --git a/d3rlpy/models/torch/distributions.py b/d3rlpy/models/torch/distributions.py index e930cb42..9cce61fe 100644 --- a/d3rlpy/models/torch/distributions.py +++ b/d3rlpy/models/torch/distributions.py @@ -1,6 +1,5 @@ import math from abc import ABCMeta, abstractmethod -from typing import Tuple import torch import torch.nn.functional as F @@ -19,7 +18,7 @@ def sample(self) -> torch.Tensor: pass @abstractmethod - def sample_with_log_prob(self) -> Tuple[torch.Tensor, torch.Tensor]: + def sample_with_log_prob(self) -> tuple[torch.Tensor, torch.Tensor]: pass @abstractmethod @@ -33,7 +32,7 @@ def onnx_safe_sample_n(self, n: int) -> torch.Tensor: @abstractmethod def sample_n_with_log_prob( self, n: int - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: pass @abstractmethod @@ -61,7 +60,7 @@ def __init__( def sample(self) -> torch.Tensor: return self._dist.rsample().clamp(-1.0, 1.0) - def sample_with_log_prob(self) -> Tuple[torch.Tensor, torch.Tensor]: + def sample_with_log_prob(self) -> tuple[torch.Tensor, torch.Tensor]: y = self.sample() return y, self.log_prob(y) @@ -87,7 +86,7 @@ def onnx_safe_sample_n(self, n: int) -> torch.Tensor: def sample_n_with_log_prob( self, n: int - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: x = self.sample_n(n) return x, self.log_prob(x.transpose(0, 1)).transpose(0, 1) @@ -95,7 +94,7 @@ def sample_n_without_squash(self, n: int) -> torch.Tensor: assert self._raw_loc is not None return Normal(self._raw_loc, self._std).rsample((n,)).transpose(0, 1) - def mean_with_log_prob(self) -> Tuple[torch.Tensor, torch.Tensor]: + def mean_with_log_prob(self) -> tuple[torch.Tensor, torch.Tensor]: return self._mean, self.log_prob(self._mean) def log_prob(self, y: torch.Tensor) -> torch.Tensor: @@ -123,7 +122,7 @@ def __init__(self, loc: torch.Tensor, std: torch.Tensor): def sample(self) -> torch.Tensor: return torch.tanh(self._dist.rsample()) - def sample_with_log_prob(self) -> Tuple[torch.Tensor, torch.Tensor]: + def sample_with_log_prob(self) -> tuple[torch.Tensor, torch.Tensor]: raw_y = self._dist.rsample() log_prob = self._log_prob_from_raw_y(raw_y) return torch.tanh(raw_y), log_prob @@ -149,7 +148,7 @@ def onnx_safe_sample_n(self, n: int) -> torch.Tensor: def sample_n_with_log_prob( self, n: int - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: raw_y = self._dist.rsample((n,)) log_prob = self._log_prob_from_raw_y(raw_y) return torch.tanh(raw_y).transpose(0, 1), log_prob.transpose(0, 1) @@ -157,7 +156,7 @@ def sample_n_with_log_prob( def sample_n_without_squash(self, n: int) -> torch.Tensor: return self._dist.rsample((n,)).transpose(0, 1) - def mean_with_log_prob(self) -> Tuple[torch.Tensor, torch.Tensor]: + def mean_with_log_prob(self) -> tuple[torch.Tensor, torch.Tensor]: return torch.tanh(self._mean), self._log_prob_from_raw_y(self._mean) def log_prob(self, y: torch.Tensor) -> torch.Tensor: diff --git a/d3rlpy/models/torch/encoders.py b/d3rlpy/models/torch/encoders.py index 9929e663..e57be61e 100644 --- a/d3rlpy/models/torch/encoders.py +++ b/d3rlpy/models/torch/encoders.py @@ -1,5 +1,5 @@ from abc import ABCMeta, abstractmethod -from typing import List, Optional, Sequence +from typing import Optional, Sequence import torch import torch.nn.functional as F @@ -48,7 +48,7 @@ class PixelEncoder(Encoder): def __init__( self, observation_shape: Sequence[int], - filters: Optional[List[List[int]]] = None, + filters: Optional[list[list[int]]] = None, feature_size: int = 512, use_batch_norm: bool = False, dropout_rate: Optional[float] = False, @@ -90,7 +90,7 @@ def __init__( cnn_output_size = self._cnn_layers(x).view(1, -1).shape[1] # last dense layer - layers: List[nn.Module] = [] + layers: list[nn.Module] = [] layers.append(nn.Linear(cnn_output_size, feature_size)) if not exclude_last_activation: layers.append(last_activation if last_activation else activation) @@ -117,7 +117,7 @@ def __init__( self, observation_shape: Sequence[int], action_size: int, - filters: Optional[List[List[int]]] = None, + filters: Optional[list[list[int]]] = None, feature_size: int = 512, use_batch_norm: bool = False, dropout_rate: Optional[float] = False, @@ -162,7 +162,7 @@ def __init__( cnn_output_size = self._cnn_layers(x).view(1, -1).shape[1] # last dense layer - layers: List[nn.Module] = [] + layers: list[nn.Module] = [] layers.append(nn.Linear(cnn_output_size + action_size, feature_size)) if not exclude_last_activation: layers.append(last_activation if last_activation else activation) diff --git a/d3rlpy/models/torch/q_functions/ensemble_q_function.py b/d3rlpy/models/torch/q_functions/ensemble_q_function.py index 9b444265..3bf1b92f 100644 --- a/d3rlpy/models/torch/q_functions/ensemble_q_function.py +++ b/d3rlpy/models/torch/q_functions/ensemble_q_function.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Union import torch @@ -118,7 +118,7 @@ def compute_ensemble_q_function_target( lam: float = 0.75, ) -> torch.Tensor: batch_size = get_batch_size(x) - values_list: List[torch.Tensor] = [] + values_list: list[torch.Tensor] = [] for forwarder in forwarders: if isinstance(forwarder, ContinuousQFunctionForwarder): assert action is not None @@ -289,7 +289,7 @@ def compute_max_with_n_actions_and_indices( actions: torch.Tensor, forwarder: ContinuousEnsembleQFunctionForwarder, lam: float, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Returns weighted target value from sampled actions. This calculation is proposed in BCQ paper for the first time. diff --git a/d3rlpy/models/torch/transformers.py b/d3rlpy/models/torch/transformers.py index 0d7607c6..1fdd8e4d 100644 --- a/d3rlpy/models/torch/transformers.py +++ b/d3rlpy/models/torch/transformers.py @@ -1,6 +1,5 @@ import math from abc import ABCMeta, abstractmethod -from typing import Tuple import torch import torch.nn.functional as F @@ -386,7 +385,7 @@ def forward( action: torch.Tensor, return_to_go: torch.Tensor, timesteps: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: batch_size, context_size, _ = return_to_go.shape position_embedding = self._position_encoding(timesteps) @@ -475,7 +474,7 @@ def forward( observation_masks: torch.Tensor, observation_positions: torch.Tensor, action_masks: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # TODO: Support text and patch tokens assert tokens.ndim == 2 batch_size, context_size = tokens.shape diff --git a/d3rlpy/ope/torch/fqe_impl.py b/d3rlpy/ope/torch/fqe_impl.py index a1549573..ff4923b1 100644 --- a/d3rlpy/ope/torch/fqe_impl.py +++ b/d3rlpy/ope/torch/fqe_impl.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Dict, Union +from typing import Union import torch from torch import nn @@ -103,7 +103,7 @@ def inner_sample_action(self, x: torch.Tensor) -> torch.Tensor: def inner_update( self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + ) -> dict[str, float]: next_actions = self._algo.predict_best_action(batch.next_observations) q_tpn = self.compute_target(batch, next_actions) diff --git a/d3rlpy/optimizers/optimizers.py b/d3rlpy/optimizers/optimizers.py index e4467a02..2b965251 100644 --- a/d3rlpy/optimizers/optimizers.py +++ b/d3rlpy/optimizers/optimizers.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Iterable, Mapping, Optional, Sequence, Tuple +from typing import Any, Iterable, Mapping, Optional, Sequence from torch import nn from torch.optim import SGD, Adam, AdamW, Optimizer, RMSprop @@ -22,7 +22,7 @@ def _get_parameters_from_named_modules( - named_modules: Iterable[Tuple[str, nn.Module]] + named_modules: Iterable[tuple[str, nn.Module]], ) -> Sequence[nn.Parameter]: # retrieve unique set of parameters params_dict = {} @@ -121,7 +121,7 @@ class OptimizerFactory(DynamicConfig): def create( self, - named_modules: Iterable[Tuple[str, nn.Module]], + named_modules: Iterable[tuple[str, nn.Module]], lr: float, compiled: bool, ) -> OptimizerWrapper: @@ -152,7 +152,7 @@ def create( ) def create_optimizer( - self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float + self, named_modules: Iterable[tuple[str, nn.Module]], lr: float ) -> Optimizer: raise NotImplementedError @@ -182,7 +182,7 @@ class SGDFactory(OptimizerFactory): nesterov: bool = False def create_optimizer( - self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float + self, named_modules: Iterable[tuple[str, nn.Module]], lr: float ) -> Optimizer: return SGD( _get_parameters_from_named_modules(named_modules), @@ -218,13 +218,13 @@ class AdamFactory(OptimizerFactory): amsgrad: flag to use the AMSGrad variant of this algorithm. """ - betas: Tuple[float, float] = (0.9, 0.999) + betas: tuple[float, float] = (0.9, 0.999) eps: float = 1e-8 weight_decay: float = 0 amsgrad: bool = False def create_optimizer( - self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float + self, named_modules: Iterable[tuple[str, nn.Module]], lr: float ) -> Adam: return Adam( params=_get_parameters_from_named_modules(named_modules), @@ -260,13 +260,13 @@ class AdamWFactory(OptimizerFactory): amsgrad: flag to use the AMSGrad variant of this algorithm. """ - betas: Tuple[float, float] = (0.9, 0.999) + betas: tuple[float, float] = (0.9, 0.999) eps: float = 1e-8 weight_decay: float = 0 amsgrad: bool = False def create_optimizer( - self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float + self, named_modules: Iterable[tuple[str, nn.Module]], lr: float ) -> AdamW: return AdamW( _get_parameters_from_named_modules(named_modules), @@ -310,7 +310,7 @@ class RMSpropFactory(OptimizerFactory): centered: bool = True def create_optimizer( - self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float + self, named_modules: Iterable[tuple[str, nn.Module]], lr: float ) -> RMSprop: return RMSprop( _get_parameters_from_named_modules(named_modules), @@ -347,13 +347,13 @@ class GPTAdamWFactory(OptimizerFactory): amsgrad: flag to use the AMSGrad variant of this algorithm. """ - betas: Tuple[float, float] = (0.9, 0.999) + betas: tuple[float, float] = (0.9, 0.999) eps: float = 1e-8 weight_decay: float = 0 amsgrad: bool = False def create_optimizer( - self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float + self, named_modules: Iterable[tuple[str, nn.Module]], lr: float ) -> AdamW: named_modules = list(named_modules) params_dict = {} diff --git a/d3rlpy/serializable_config.py b/d3rlpy/serializable_config.py index 1e5ca1e2..bc279a83 100644 --- a/d3rlpy/serializable_config.py +++ b/d3rlpy/serializable_config.py @@ -2,11 +2,8 @@ from typing import ( Any, Callable, - Dict, Optional, Sequence, - Tuple, - Type, TypeVar, Union, cast, @@ -38,21 +35,21 @@ class SerializableConfig: def serialize(self) -> str: return self.to_json() # type: ignore - def serialize_to_dict(self) -> Dict[str, Any]: + def serialize_to_dict(self) -> dict[str, Any]: return self.to_dict() # type: ignore @classmethod - def deserialize(cls: Type[TConfig], serialized_config: str) -> TConfig: + def deserialize(cls: type[TConfig], serialized_config: str) -> TConfig: return cls.from_json(serialized_config) # type: ignore @classmethod def deserialize_from_dict( - cls: Type[TConfig], dict_config: Dict[str, Any] + cls: type[TConfig], dict_config: dict[str, Any] ) -> TConfig: return cls.from_dict(dict_config) # type: ignore @classmethod - def deserialize_from_file(cls: Type[TConfig], path: str) -> TConfig: + def deserialize_from_file(cls: type[TConfig], path: str) -> TConfig: with open(path, "r") as f: return cls.deserialize(f.read()) @@ -65,43 +62,43 @@ def get_type() -> str: @dataclasses.dataclass(frozen=True) class ConfigMetadata: - base_cls: Type[DynamicConfig] - encoder: Callable[[DynamicConfig], Dict[str, Any]] + base_cls: type[DynamicConfig] + encoder: Callable[[DynamicConfig], dict[str, Any]] decoder: Callable[ - [Dict[str, Any]], Union[DynamicConfig, Optional[DynamicConfig]] + [dict[str, Any]], Union[DynamicConfig, Optional[DynamicConfig]] ] - config_list: Dict[str, Type[DynamicConfig]] + config_list: dict[str, type[DynamicConfig]] - def add_config(self, name: str, new_config: Type[DynamicConfig]) -> None: + def add_config(self, name: str, new_config: type[DynamicConfig]) -> None: assert name not in self.config_list, f"{name} is already registered" self.config_list[name] = new_config -CONFIG_STORAGE: Dict[Type[DynamicConfig], ConfigMetadata] = {} +CONFIG_STORAGE: dict[type[DynamicConfig], ConfigMetadata] = {} def generate_config_registration( - base_cls: Type[TDynamicConfig], + base_cls: type[TDynamicConfig], default_factory: Optional[Callable[[], TDynamicConfig]] = None, -) -> Tuple[ - Callable[[Type[TDynamicConfig]], None], Callable[[], TDynamicConfig] +) -> tuple[ + Callable[[type[TDynamicConfig]], None], Callable[[], TDynamicConfig] ]: - CONFIG_LIST: Dict[str, Type[TDynamicConfig]] = {} + CONFIG_LIST: dict[str, type[TDynamicConfig]] = {} - def register_config(cls: Type[TDynamicConfig]) -> None: + def register_config(cls: type[TDynamicConfig]) -> None: assert issubclass(cls, base_cls) type_name = cls.get_type() is_registered = type_name in CONFIG_LIST assert not is_registered, f"{type_name} seems to be already registered" CONFIG_LIST[type_name] = cls - def _encoder(orig_config: TDynamicConfig) -> Dict[str, Any]: + def _encoder(orig_config: TDynamicConfig) -> dict[str, Any]: return { "type": orig_config.get_type(), "params": orig_config.serialize_to_dict(), } - def _decoder(dict_config: Dict[str, Any]) -> TDynamicConfig: + def _decoder(dict_config: dict[str, Any]) -> TDynamicConfig: name = dict_config["type"] params = dict_config["params"] return CONFIG_LIST[name].deserialize_from_dict(params) @@ -137,21 +134,21 @@ def make_field() -> TDynamicConfig: def generate_optional_config_generation( - base_cls: Type[TDynamicConfig], -) -> Tuple[ - Callable[[Type[TDynamicConfig]], None], + base_cls: type[TDynamicConfig], +) -> tuple[ + Callable[[type[TDynamicConfig]], None], Callable[[], Optional[TDynamicConfig]], ]: - CONFIG_LIST: Dict[str, Type[TDynamicConfig]] = {} + CONFIG_LIST: dict[str, type[TDynamicConfig]] = {} - def register_config(cls: Type[TDynamicConfig]) -> None: + def register_config(cls: type[TDynamicConfig]) -> None: assert issubclass(cls, base_cls) type_name = cls.get_type() is_registered = type_name in CONFIG_LIST assert not is_registered, f"{type_name} seems to be already registered" CONFIG_LIST[type_name] = cls - def _encoder(orig_config: Optional[TDynamicConfig]) -> Dict[str, Any]: + def _encoder(orig_config: Optional[TDynamicConfig]) -> dict[str, Any]: if orig_config is None: return {"type": "none", "params": {}} return { @@ -159,7 +156,7 @@ def _encoder(orig_config: Optional[TDynamicConfig]) -> Dict[str, Any]: "params": orig_config.serialize_to_dict(), } - def _decoder(dict_config: Dict[str, Any]) -> Optional[TDynamicConfig]: + def _decoder(dict_config: dict[str, Any]) -> Optional[TDynamicConfig]: name = dict_config["type"] params = dict_config["params"] if name == "none": @@ -184,7 +181,7 @@ def make_field() -> Optional[TDynamicConfig]: def generate_list_config_field( - base_cls: Type[TDynamicConfig], + base_cls: type[TDynamicConfig], ) -> Callable[[], Sequence[TDynamicConfig]]: assert base_cls in CONFIG_STORAGE @@ -192,11 +189,11 @@ def generate_list_config_field( def _encoder( orig_config: Sequence[TDynamicConfig], - ) -> Sequence[Dict[str, Any]]: + ) -> Sequence[dict[str, Any]]: return [config_metadata.encoder(config) for config in orig_config] def _decoder( - dict_config: Sequence[Dict[str, Any]] + dict_config: Sequence[dict[str, Any]], ) -> Sequence[TDynamicConfig]: configs = [config_metadata.decoder(config) for config in dict_config] return configs # type: ignore diff --git a/d3rlpy/torch_utility.py b/d3rlpy/torch_utility.py index 61eff4c0..8fcd19ae 100644 --- a/d3rlpy/torch_utility.py +++ b/d3rlpy/torch_utility.py @@ -3,12 +3,10 @@ from typing import ( Any, BinaryIO, - Dict, Generic, Iterator, Optional, Sequence, - Tuple, TypeVar, Union, overload, @@ -122,7 +120,7 @@ def convert_to_torch_recursively( def convert_to_numpy_recursively( - array: Union[torch.Tensor, Sequence[torch.Tensor]] + array: Union[torch.Tensor, Sequence[torch.Tensor]], ) -> Union[NDArray, Sequence[NDArray]]: if isinstance(array, (list, tuple)): return [data.numpy() for data in array] @@ -370,12 +368,12 @@ def unwrap_ddp_model(model: _TModule) -> _TModule: class Checkpointer: - _modules: Dict[str, Union[nn.Module, OptimizerWrapperProto]] + _modules: dict[str, Union[nn.Module, OptimizerWrapperProto]] _device: str def __init__( self, - modules: Dict[str, Union[nn.Module, OptimizerWrapperProto]], + modules: dict[str, Union[nn.Module, OptimizerWrapperProto]], device: str, ): self._modules = modules @@ -396,7 +394,7 @@ def load(self, f: BinaryIO) -> None: v.load_state_dict(chkpt[k]) @property - def modules(self) -> Dict[str, Union[nn.Module, OptimizerWrapperProto]]: + def modules(self) -> dict[str, Union[nn.Module, OptimizerWrapperProto]]: return self._modules @@ -437,18 +435,21 @@ def reset_optimizer_states(self) -> None: if isinstance(v, OptimizerWrapperProto): v.optim.state = collections.defaultdict(dict) - def get_torch_modules(self) -> Dict[str, nn.Module]: - torch_modules: Dict[str, nn.Module] = {} + def get_torch_modules(self) -> dict[str, nn.Module]: + torch_modules: dict[str, nn.Module] = {} for k, v in asdict_without_copy(self).items(): if isinstance(v, nn.Module): torch_modules[k] = v return torch_modules - def get_gradients(self) -> Iterator[Tuple[str, Float32NDArray]]: + def get_gradients(self) -> Iterator[tuple[str, Float32NDArray]]: for module_name, module in self.get_torch_modules().items(): for name, parameter in module.named_parameters(): if parameter.requires_grad and parameter.grad is not None: - yield f"{module_name}.{name}", parameter.grad.cpu().detach().numpy() + yield ( + f"{module_name}.{name}", + parameter.grad.cpu().detach().numpy(), + ) TCallable = TypeVar("TCallable") diff --git a/dev.requirements.txt b/dev.requirements.txt index 864017c5..272ff6aa 100644 --- a/dev.requirements.txt +++ b/dev.requirements.txt @@ -5,11 +5,9 @@ onnx matplotlib tensorboardX wandb -black mypy -pylint==2.13.5 numpy<2 -isort docformatter +ruff minari==0.4.2 gymnasium_robotics diff --git a/examples/custom_algo.py b/examples/custom_algo.py index 885df3a1..c73f5045 100644 --- a/examples/custom_algo.py +++ b/examples/custom_algo.py @@ -1,6 +1,6 @@ import copy import dataclasses -from typing import Dict, Sequence, cast +from typing import Sequence, cast import gym import torch @@ -49,7 +49,7 @@ def __init__( def inner_update( self, batch: d3rlpy.TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + ) -> dict[str, float]: self._modules.optim.zero_grad() with torch.no_grad(): diff --git a/examples/distributed_offline_training.py b/examples/distributed_offline_training.py index ba6d7471..2a68fd00 100644 --- a/examples/distributed_offline_training.py +++ b/examples/distributed_offline_training.py @@ -1,4 +1,3 @@ -from typing import Dict import d3rlpy @@ -34,7 +33,7 @@ def main() -> None: # disable logging on rank != 0 workers logger_adapter: d3rlpy.logging.LoggerAdapterFactory - evaluators: Dict[str, d3rlpy.metrics.EvaluatorProtocol] + evaluators: dict[str, d3rlpy.metrics.EvaluatorProtocol] if rank == 0: evaluators = {"environment": d3rlpy.metrics.EnvironmentEvaluator(env)} logger_adapter = d3rlpy.logging.FileAdapterFactory() diff --git a/mypy.ini b/mypy.ini index 83238165..4e09045b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,5 +1,5 @@ [mypy] -python_version = 3.8 +python_version = 3.9 strict = True strict_optional = True disallow_untyped_defs = True diff --git a/pylintrc b/pylintrc deleted file mode 100644 index fdea5f83..00000000 --- a/pylintrc +++ /dev/null @@ -1,615 +0,0 @@ -[MASTER] - -# A comma-separated list of package or module names from where C extensions may -# be loaded. Extensions are loading into the active Python interpreter and may -# run arbitrary code. -extension-pkg-whitelist= - -# Specify a score threshold to be exceeded before program exits with error. -fail-under=10.0 - -# Add files or directories to the blacklist. They should be base names, not -# paths. -ignore=CVS - -# Add files or directories matching the regex patterns to the blacklist. The -# regex matches against base names, not paths. -ignore-patterns= - -# Python code to execute, usually for sys.path manipulation such as -# pygtk.require(). -#init-hook= - -# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the -# number of processors available to use. -jobs=1 - -# Control the amount of potential inferred values when inferring a single -# object. This can help the performance when dealing with large functions or -# complex, nested conditions. -limit-inference-results=100 - -# List of plugins (as comma separated values of python module names) to load, -# usually to register additional checkers. -load-plugins= - -# Pickle collected data for later comparisons. -persistent=yes - -# When enabled, pylint would attempt to guess common misconfiguration and emit -# user-friendly hints instead of false-positive error messages. -suggestion-mode=yes - -# Allow loading of arbitrary C extensions. Extensions are imported into the -# active Python interpreter and may run arbitrary code. -unsafe-load-any-extension=no - - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. -confidence= - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once). You can also use "--disable=all" to -# disable everything first and then reenable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use "--disable=all --enable=classes -# --disable=W". -disable=print-statement, - parameter-unpacking, - unpacking-in-except, - old-raise-syntax, - backtick, - long-suffix, - old-ne-operator, - old-octal-literal, - import-star-module-level, - non-ascii-bytes-literal, - raw-checker-failed, - bad-inline-option, - locally-disabled, - file-ignored, - suppressed-message, - useless-suppression, - deprecated-pragma, - use-symbolic-message-instead, - apply-builtin, - basestring-builtin, - buffer-builtin, - cmp-builtin, - coerce-builtin, - execfile-builtin, - file-builtin, - long-builtin, - raw_input-builtin, - reduce-builtin, - standarderror-builtin, - unicode-builtin, - xrange-builtin, - coerce-method, - delslice-method, - getslice-method, - setslice-method, - no-absolute-import, - old-division, - dict-iter-method, - dict-view-method, - next-method-called, - metaclass-assignment, - indexing-exception, - raising-string, - reload-builtin, - oct-method, - hex-method, - nonzero-method, - cmp-method, - input-builtin, - round-builtin, - intern-builtin, - unichr-builtin, - map-builtin-not-iterating, - zip-builtin-not-iterating, - range-builtin-not-iterating, - filter-builtin-not-iterating, - using-cmp-argument, - eq-without-hash, - div-method, - idiv-method, - rdiv-method, - exception-message-attribute, - invalid-str-codec, - sys-max-int, - bad-python3-import, - deprecated-string-function, - deprecated-str-translate-call, - deprecated-itertools-function, - deprecated-types-field, - next-method-defined, - dict-items-not-iterating, - dict-keys-not-iterating, - dict-values-not-iterating, - deprecated-operator-function, - deprecated-urllib-function, - xreadlines-attribute, - deprecated-sys-function, - exception-escape, - comprehension-escape, - duplicate-code, - missing-function-docstring, - missing-class-docstring, - missing-module-docstring, - too-few-public-methods, - invalid-name, - no-name-in-module, - import-outside-toplevel, - no-member, - too-many-arguments, - too-many-instance-attributes, - too-many-locals, - abstract-method, - not-callable, - fixme, - unused-argument, - line-too-long, - too-many-branches, - too-many-statements, - no-else-return, - too-many-lines, - no-self-use, - assignment-from-none, - too-many-public-methods, - stop-iteration-return, - unspecified-encoding, - unnecessary-lambda - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -enable=c-extension-no-member - - -[REPORTS] - -# Python expression which should return a score less than or equal to 10. You -# have access to the variables 'error', 'warning', 'refactor', and 'convention' -# which contain the number of messages in each category, as well as 'statement' -# which is the total number of statements analyzed. This score is used by the -# global evaluation report (RP0004). -evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details. -#msg-template= - -# Set the output format. Available formats are text, parseable, colorized, json -# and msvs (visual studio). You can also give a reporter class, e.g. -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Tells whether to display a full report or only the messages. -reports=no - -# Activate the evaluation score. -score=yes - - -[REFACTORING] - -# Maximum number of nested blocks for function / method body -max-nested-blocks=5 - -# Complete name of functions that never returns. When checking for -# inconsistent-return-statements if a never returning function is called then -# it will be considered as an explicit return statement and no message will be -# printed. -never-returning-functions=sys.exit - - -[LOGGING] - -# The type of string formatting that logging methods do. `old` means using % -# formatting, `new` is for `{}` formatting. -logging-format-style=old - -# Logging modules to check that the string format arguments are in logging -# function parameter format. -logging-modules=logging - - -[SPELLING] - -# Limits count of emitted suggestions for spelling mistakes. -max-spelling-suggestions=4 - -# Spelling dictionary name. Available dictionaries: none. To make it work, -# install the python-enchant package. -spelling-dict= - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains the private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to the private dictionary (see the -# --spelling-private-dict-file option) instead of raising a message. -spelling-store-unknown-words=no - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=FIXME, - XXX, - TODO - -# Regular expression of note tags to take in consideration. -#notes-rgx= - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members= - -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes - -# Tells whether to warn about missing members when the owner of the attribute -# is inferred to be None. -ignore-none=yes - -# This flag controls whether pylint should warn about no-member and similar -# checks whenever an opaque object is returned when inferring. The inference -# can return multiple potential results while evaluating a Python object, but -# some branches might not be evaluated, which results in partial inference. In -# that case, it might be useful to still emit no-member and other checks for -# the rest of the inferred objects. -ignore-on-opaque-inference=yes - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis). It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules= - -# Show a hint with possible names when a member name was not found. The aspect -# of finding the hint is based on edit distance. -missing-member-hint=yes - -# The minimum edit distance a name should have in order to be considered a -# similar match for a missing member name. -missing-member-hint-distance=1 - -# The total number of similar names that should be taken in consideration when -# showing a hint for a missing member. -missing-member-max-choices=1 - -# List of decorators that change the signature of a decorated function. -signature-mutators= - - -[VARIABLES] - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid defining new builtins when possible. -additional-builtins= - -# Tells whether unused global variables should be treated as a violation. -allow-global-unused-variables=yes - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_, - _cb - -# A regular expression matching the name of dummy variables (i.e. expected to -# not be used). -dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ - -# Argument names that match this expression will be ignored. Default to name -# with leading underscore. -ignored-argument-names=_.*|^ignored_|^unused_ - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# List of qualified module names which can have objects that can redefine -# builtins. -redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io - - -[FORMAT] - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=^\s*(# )??$ - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 -# tab). -indent-string=' ' - -# Maximum number of characters on a single line. -max-line-length=100 - -# Maximum number of lines in a module. -max-module-lines=1000 - -# Allow the body of a class to be on the same line as the declaration if body -# contains single statement. -single-line-class-stmt=no - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=no - - -[SIMILARITIES] - -# Ignore comments when computing similarities. -ignore-comments=yes - -# Ignore docstrings when computing similarities. -ignore-docstrings=yes - -# Ignore imports when computing similarities. -ignore-imports=no - -# Minimum lines number of a similarity. -min-similarity-lines=4 - - -[BASIC] - -# Naming style matching correct argument names. -argument-naming-style=snake_case - -# Regular expression matching correct argument names. Overrides argument- -# naming-style. -#argument-rgx= - -# Naming style matching correct attribute names. -attr-naming-style=snake_case - -# Regular expression matching correct attribute names. Overrides attr-naming- -# style. -#attr-rgx= - -# Bad variable names which should always be refused, separated by a comma. -bad-names=foo, - bar, - baz, - toto, - tutu, - tata - -# Bad variable names regexes, separated by a comma. If names match any regex, -# they will always be refused -bad-names-rgxs= - -# Naming style matching correct class attribute names. -class-attribute-naming-style=any - -# Regular expression matching correct class attribute names. Overrides class- -# attribute-naming-style. -#class-attribute-rgx= - -# Naming style matching correct class names. -class-naming-style=PascalCase - -# Regular expression matching correct class names. Overrides class-naming- -# style. -#class-rgx= - -# Naming style matching correct constant names. -const-naming-style=UPPER_CASE - -# Regular expression matching correct constant names. Overrides const-naming- -# style. -#const-rgx= - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=-1 - -# Naming style matching correct function names. -function-naming-style=snake_case - -# Regular expression matching correct function names. Overrides function- -# naming-style. -#function-rgx= - -# Good variable names which should always be accepted, separated by a comma. -good-names=i, - j, - k, - ex, - Run, - _ - -# Good variable names regexes, separated by a comma. If names match any regex, -# they will always be accepted -good-names-rgxs= - -# Include a hint for the correct naming format with invalid-name. -include-naming-hint=no - -# Naming style matching correct inline iteration names. -inlinevar-naming-style=any - -# Regular expression matching correct inline iteration names. Overrides -# inlinevar-naming-style. -#inlinevar-rgx= - -# Naming style matching correct method names. -method-naming-style=snake_case - -# Regular expression matching correct method names. Overrides method-naming- -# style. -#method-rgx= - -# Naming style matching correct module names. -module-naming-style=snake_case - -# Regular expression matching correct module names. Overrides module-naming- -# style. -#module-rgx= - -# Colon-delimited sets of names that determine each other's naming style when -# the name regexes allow several styles. -name-group= - -# Regular expression which should only match function or class names that do -# not require a docstring. -no-docstring-rgx=^_ - -# List of decorators that produce properties, such as abc.abstractproperty. Add -# to this list to register other decorators that produce valid properties. -# These decorators are taken in consideration only for invalid-name. -property-classes=abc.abstractproperty - -# Naming style matching correct variable names. -variable-naming-style=snake_case - -# Regular expression matching correct variable names. Overrides variable- -# naming-style. -#variable-rgx= - - -[STRING] - -# This flag controls whether inconsistent-quotes generates a warning when the -# character used as a quote delimiter is used inconsistently within a module. -check-quote-consistency=no - -# This flag controls whether the implicit-str-concat should generate a warning -# on implicit string concatenation in sequences defined over several lines. -check-str-concat-over-line-jumps=no - - -[IMPORTS] - -# List of modules that can be imported at any level, not just the top level -# one. -allow-any-import-level= - -# Allow wildcard imports from modules that define __all__. -allow-wildcard-with-all=no - -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - -# Deprecated modules which should not be used, separated by a comma. -deprecated-modules=optparse,tkinter.tix - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled). -ext-import-graph= - -# Create a graph of every (i.e. internal and external) dependencies in the -# given file (report RP0402 must not be disabled). -import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled). -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. -known-third-party=enchant - -# Couples of modules and preferred modules, separated by a comma. -preferred-modules= - - -[CLASSES] - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__, - __new__, - setUp, - __post_init__ - -# List of member names, which should be excluded from the protected access -# warning. -exclude-protected=_asdict, - _fields, - _replace, - _source, - _make - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=cls - - -[DESIGN] - -# Maximum number of arguments for function / method. -max-args=5 - -# Maximum number of attributes for a class (see R0902). -max-attributes=7 - -# Maximum number of boolean expressions in an if statement (see R0916). -max-bool-expr=5 - -# Maximum number of branch for function / method body. -max-branches=12 - -# Maximum number of locals for function / method body. -max-locals=15 - -# Maximum number of parents for a class (see R0901). -max-parents=7 - -# Maximum number of public methods for a class (see R0904). -max-public-methods=20 - -# Maximum number of return / yield for function / method body. -max-returns=6 - -# Maximum number of statements in function / method body. -max-statements=50 - -# Minimum number of public methods for a class (see R0903). -min-public-methods=2 - - -[EXCEPTIONS] - -# Exceptions that will emit a warning when being caught. Defaults to -# "BaseException, Exception". -overgeneral-exceptions=BaseException, - Exception diff --git a/reproductions/offline/rebrac.py b/reproductions/offline/rebrac.py index 3098c8dc..a9a252d6 100644 --- a/reproductions/offline/rebrac.py +++ b/reproductions/offline/rebrac.py @@ -1,9 +1,8 @@ import argparse -from typing import Dict, Tuple import d3rlpy -BETA_TABLE: Dict[str, Tuple[float, float]] = { +BETA_TABLE: dict[str, tuple[float, float]] = { "halfcheetah-random": (0.001, 0.1), "halfcheetah-medium": (0.001, 0.01), "halfcheetah-expert": (0.01, 0.01), diff --git a/requirements.txt b/requirements.txt index 604cca3b..bb244847 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,6 @@ tqdm>=4.66.1 h5py==2.10.0 gym==0.26.2 click==8.0.1 -typing-extensions==3.7.4.3 structlog==20.2.0 colorama==0.4.4 gymnasium==1.0.0 diff --git a/scripts/format b/scripts/format index 2ea5bf72..087593a5 100755 --- a/scripts/format +++ b/scripts/format @@ -1,20 +1,3 @@ #!/bin/bash -ex -if [[ -z $CI ]]; then - ISORT_ARG="" - BLACK_ARG="" - DOCFORMATTER_ARG="--in-place" -else - ISORT_ARG="--check --diff" - BLACK_ARG="--check" - DOCFORMATTER_ARG="--check --diff" -fi -# format package imports -isort -l 80 --profile black $ISORT_ARG d3rlpy tests setup.py reproductions examples - -# use black for the better type annotations -black -l 80 $BLACK_ARG d3rlpy tests setup.py reproductions examples - -# format docstrings -docformatter $DOCFORMATTER_ARG --black --wrap-summaries 80 --wrap-descriptions 80 -r d3rlpy diff --git a/scripts/lint b/scripts/lint index 3b666458..d3c888cf 100755 --- a/scripts/lint +++ b/scripts/lint @@ -1,7 +1,18 @@ -#!/bin/bash -eux +#!/bin/bash -ex + +if [[ -z $CI ]]; then + RUFF_ARG="check --fix" + DOCFORMATTER_ARG="--in-place" +else + RUFF_ARG="check" + DOCFORMATTER_ARG="--check --diff" +fi + +# formatter and linter +ruff $RUFF_ARG d3rlpy tests examples reproductions setup.py + +# format docstrings +docformatter $DOCFORMATTER_ARG --black --wrap-summaries 80 --wrap-descriptions 80 -r d3rlpy # type check mypy d3rlpy reproductions tests examples - -# code-format check -pylint d3rlpy reproductions tests examples diff --git a/setup.py b/setup.py index b6be1cb6..2cf5b81e 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ if __name__ == "__main__": setup( name="d3rlpy", - version=__version__, + version=__version__, # noqa description="An offline deep reinforcement learning library", long_description=open("README.md").read(), long_description_content_type="text/markdown", @@ -24,9 +24,9 @@ "Intended Audience :: Science/Research", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Programming Language :: Python :: Implementation :: CPython", "Operating System :: POSIX :: Linux", "Operating System :: Microsoft :: Windows", @@ -38,14 +38,13 @@ "h5py", "gym>=0.26.0", "click", - "typing-extensions", "structlog", "colorama", "dataclasses-json", "gymnasium>=1.0.0", ], packages=find_packages(exclude=["tests*"]), - python_requires=">=3.8.0", + python_requires=">=3.9.0", zip_safe=True, entry_points={"console_scripts": ["d3rlpy=d3rlpy.cli:cli"]}, ) diff --git a/tests/algos/qlearning/algo_test.py b/tests/algos/qlearning/algo_test.py index b1ece560..8463041e 100644 --- a/tests/algos/qlearning/algo_test.py +++ b/tests/algos/qlearning/algo_test.py @@ -152,7 +152,7 @@ def predict_tester( x = create_observations(observation_shape, 100) y = algo.predict(x) if algo.get_action_type() == ActionSpace.DISCRETE: - assert y.shape == (100,) # type: ignore + assert y.shape == (100,) else: assert y.shape == (100, action_size) @@ -166,7 +166,7 @@ def sample_action_tester( x = create_observations(observation_shape, 100) y = algo.sample_action(x) if algo.get_action_type() == ActionSpace.DISCRETE: - assert y.shape == (100,) # type: ignore + assert y.shape == (100,) else: assert y.shape == (100, action_size) @@ -238,7 +238,7 @@ def predict_value_tester( action = np.random.random((100, action_size)) value = algo.predict_value(x, action) - assert value.shape == (100,) # type: ignore + assert value.shape == (100,) def save_and_load_tester( diff --git a/tests/algos/qlearning/test_random_policy.py b/tests/algos/qlearning/test_random_policy.py index 15836b28..407906e9 100644 --- a/tests/algos/qlearning/test_random_policy.py +++ b/tests/algos/qlearning/test_random_policy.py @@ -47,10 +47,10 @@ def test_discrete_random_policy( # check predict action = algo.predict(x) - assert action.shape == (batch_size,) # type: ignore + assert action.shape == (batch_size,) assert np.all(action < action_size) # check sample_action action = algo.sample_action(x) - assert action.shape == (batch_size,) # type: ignore + assert action.shape == (batch_size,) assert np.all(action < action_size) diff --git a/tests/algos/transformer/algo_test.py b/tests/algos/transformer/algo_test.py index 0f2dfc18..a5137ff1 100644 --- a/tests/algos/transformer/algo_test.py +++ b/tests/algos/transformer/algo_test.py @@ -1,5 +1,5 @@ import os -from typing import Any, List +from typing import Any from unittest.mock import Mock import numpy as np @@ -125,9 +125,9 @@ def predict_tester( ) y = algo.predict(inpt) if algo.get_action_type() == ActionSpace.DISCRETE: - assert y.shape == (action_size,) # type: ignore + assert y.shape == (action_size,) else: - assert y.shape == (action_size,) # type: ignore + assert y.shape == (action_size,) def save_and_load_tester( @@ -225,7 +225,7 @@ def stateful_wrapper_tester( assert isinstance(action, int) else: assert isinstance(action, np.ndarray) - assert action.shape == (action_size,) # type: ignore + assert action.shape == (action_size,) wrapper.reset() # check reset @@ -252,7 +252,7 @@ def save_policy_tester( algo.save_policy(os.path.join("test_data", "model.pt")) policy = torch.jit.load(os.path.join("test_data", "model.pt")) - inputs: List[Any] = [] + inputs: list[Any] = [] torch_observations = create_torch_observations( observation_shape, algo.config.context_size ) diff --git a/tests/base_test.py b/tests/base_test.py index dec99da4..62263b7c 100644 --- a/tests/base_test.py +++ b/tests/base_test.py @@ -19,7 +19,7 @@ def _check_reconst_algo( new_algo: LearnableBase[ImplBase, LearnableConfig], ) -> None: assert new_algo.impl is not None - assert type(new_algo) == type(algo) + assert type(new_algo) is type(algo) assert algo.observation_shape is not None assert new_algo.observation_shape is not None if isinstance(algo.observation_shape[0], int): @@ -35,7 +35,7 @@ def _check_reconst_algo( if algo.observation_scaler is None: assert new_algo.observation_scaler is None else: - assert type(algo.observation_scaler) == type( + assert type(algo.observation_scaler) is type( new_algo.observation_scaler ) @@ -43,13 +43,13 @@ def _check_reconst_algo( if algo.action_scaler is None: assert new_algo.action_scaler is None else: - assert type(algo.action_scaler) == type(new_algo.action_scaler) + assert type(algo.action_scaler) is type(new_algo.action_scaler) # check reward scaler if algo.reward_scaler is None: assert new_algo.reward_scaler is None else: - assert type(algo.reward_scaler) == type(new_algo.reward_scaler) + assert type(algo.reward_scaler) is type(new_algo.reward_scaler) def from_json_tester( diff --git a/tests/dataset/test_mini_batch.py b/tests/dataset/test_mini_batch.py index 516fdb16..0992d11d 100644 --- a/tests/dataset/test_mini_batch.py +++ b/tests/dataset/test_mini_batch.py @@ -92,10 +92,10 @@ def test_trajectory_mini_batch( length, *observation_shape, ) - assert batch.actions.shape == (batch_size, length, action_size) # type: ignore - assert batch.rewards.shape == (batch_size, length, 1) # type: ignore - assert batch.returns_to_go.shape == (batch_size, length, 1) # type: ignore - assert batch.terminals.shape == (batch_size, length, 1) # type: ignore + assert batch.actions.shape == (batch_size, length, action_size) + assert batch.rewards.shape == (batch_size, length, 1) + assert batch.returns_to_go.shape == (batch_size, length, 1) + assert batch.terminals.shape == (batch_size, length, 1) assert batch.timesteps.shape == (batch_size, length) assert batch.masks.shape == (batch_size, length) assert batch.length == length diff --git a/tests/dataset/test_trajectory_slicer.py b/tests/dataset/test_trajectory_slicer.py index 301b509f..2f199489 100644 --- a/tests/dataset/test_trajectory_slicer.py +++ b/tests/dataset/test_trajectory_slicer.py @@ -55,8 +55,8 @@ def test_basic_trajectory_slicer( assert traj.rewards.shape == (size, 1) assert traj.returns_to_go.shape == (size, 1) assert traj.terminals.shape == (size, 1) - assert traj.timesteps.shape == (size,) # type: ignore - assert traj.masks.shape == (size,) # type: ignore + assert traj.timesteps.shape == (size,) + assert traj.masks.shape == (size,) assert traj.length == size # check values @@ -142,8 +142,8 @@ def test_frame_stack_trajectory_slicer( assert traj.rewards.shape == (size, 1) assert traj.returns_to_go.shape == (size, 1) assert traj.terminals.shape == (size, 1) - assert traj.timesteps.shape == (size,) # type: ignore - assert traj.masks.shape == (size,) # type: ignore + assert traj.timesteps.shape == (size,) + assert traj.masks.shape == (size,) assert traj.length == size # check values diff --git a/tests/dataset/test_utils.py b/tests/dataset/test_utils.py index 6e6aa4a6..7596c1e9 100644 --- a/tests/dataset/test_utils.py +++ b/tests/dataset/test_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Sequence, Tuple, Union, cast +from typing import Any, Sequence, Union, cast import gym import gymnasium @@ -67,7 +67,7 @@ def test_create_zero_observation(observation_shape: Shape) -> None: @pytest.mark.parametrize("length", [100]) @pytest.mark.parametrize("index", [(0, 5)]) def test_slice_observations( - observation_shape: Shape, length: int, index: Tuple[int, int] + observation_shape: Shape, length: int, index: tuple[int, int] ) -> None: observations = create_observations(observation_shape, length) diff --git a/tests/dummy_env.py b/tests/dummy_env.py index 932df31c..77aa43aa 100644 --- a/tests/dummy_env.py +++ b/tests/dummy_env.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Tuple +from typing import Any import gym import numpy as np @@ -23,11 +23,11 @@ def __init__(self, grayscale: bool = True, squeeze: bool = False): def step( self, action: int - ) -> Tuple[NDArray, float, bool, bool, Dict[str, Any]]: + ) -> tuple[NDArray, float, bool, bool, dict[str, Any]]: observation = self.observation_space.sample() reward = np.random.random() return observation, reward, False, self.t % 80 == 0, {} - def reset(self, **kwargs: Any) -> Tuple[NDArray, Dict[str, Any]]: + def reset(self, **kwargs: Any) -> tuple[NDArray, dict[str, Any]]: self.t = 1 return self.observation_space.sample(), {} diff --git a/tests/envs/test_wrappers.py b/tests/envs/test_wrappers.py index 81a5ef80..287bdf60 100644 --- a/tests/envs/test_wrappers.py +++ b/tests/envs/test_wrappers.py @@ -18,11 +18,11 @@ def test_channel_first() -> None: # check reset observation, _ = wrapper.reset() - assert observation.shape == (channel, width, height) # type: ignore + assert observation.shape == (channel, width, height) # check step observation, _, _, _, _ = wrapper.step(wrapper.action_space.sample()) - assert observation.shape == (channel, width, height) # type: ignore + assert observation.shape == (channel, width, height) # check with algorithm dqn = DQNConfig().create() @@ -40,11 +40,11 @@ def test_channel_first_with_2_dim_obs() -> None: # check reset observation, _ = wrapper.reset() - assert observation.shape == (1, width, height) # type: ignore + assert observation.shape == (1, width, height) # check step observation, _, _, _, _ = wrapper.step(wrapper.action_space.sample()) - assert observation.shape == (1, width, height) # type: ignore + assert observation.shape == (1, width, height) # check with algorithm dqn = DQNConfig().create() @@ -63,11 +63,11 @@ def test_frame_stack(num_stack: int) -> None: # check reset observation, _ = wrapper.reset() - assert observation.shape == (num_stack, width, height) # type: ignore + assert observation.shape == (num_stack, width, height) # check step observation, _, _, _, _ = wrapper.step(wrapper.action_space.sample()) - assert observation.shape == (num_stack, width, height) # type: ignore + assert observation.shape == (num_stack, width, height) # check with algorithm dqn = DQNConfig().create() @@ -84,11 +84,11 @@ def test_atari(is_eval: bool) -> None: # check reset observation, _ = env.reset() - assert observation.shape == (1, 84, 84) # type: ignore + assert observation.shape == (1, 84, 84) # check step observation, _, _, _, _ = env.step(env.action_space.sample()) - assert observation.shape == (1, 84, 84) # type: ignore + assert observation.shape == (1, 84, 84) # @pytest.mark.parametrize("tuple_observation", [True, False]) diff --git a/tests/logging/test_logger.py b/tests/logging/test_logger.py index cf484a67..50f3977b 100644 --- a/tests/logging/test_logger.py +++ b/tests/logging/test_logger.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any import pytest from torch import nn @@ -18,7 +18,7 @@ def __init__(self, experiment_name: str): self.is_close_called = False self.is_watch_model_called = False - def write_params(self, params: Dict[str, Any]) -> None: + def write_params(self, params: dict[str, Any]) -> None: self.is_write_params_called = True def before_write_metric(self, epoch: int, step: int) -> None: @@ -57,7 +57,7 @@ def create( class StubModules: - def get_torch_modules(self) -> List[nn.Module]: + def get_torch_modules(self) -> list[nn.Module]: return [] diff --git a/tests/metrics/test_evaluators.py b/tests/metrics/test_evaluators.py index 997e4564..837c411b 100644 --- a/tests/metrics/test_evaluators.py +++ b/tests/metrics/test_evaluators.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional, Sequence +from typing import Callable, Optional, Sequence import numpy as np import pytest @@ -98,7 +98,7 @@ def ref_td_error_score( terminals: NDArray, gamma: float, reward_scaler: Optional[RewardScaler], -) -> List[float]: +) -> list[float]: if reward_scaler: rewards = reward_scaler.transform_numpy(rewards) values = predict_value(observations, actions) @@ -140,7 +140,7 @@ def test_td_error_scorer( algo = DummyAlgo(A, gamma, reward_scaler=reward_scaler) - ref_errors: List[float] = [] + ref_errors: list[float] = [] for episode in episodes: batch = _convert_episode_to_batch(episode) ref_error = ref_td_error_score( @@ -199,7 +199,7 @@ def ref_discounted_sum_of_advantage_score( dataset_actions: NDArray, policy_actions: NDArray, gamma: float, -) -> List[float]: +) -> list[float]: dataset_values = predict_value(observations, dataset_actions) policy_values = predict_value(observations, policy_actions) advantages = (dataset_values - policy_values).reshape(-1).tolist() diff --git a/tests/metrics/test_utility.py b/tests/metrics/test_utility.py index e32e48d6..a4ce1d7d 100644 --- a/tests/metrics/test_utility.py +++ b/tests/metrics/test_utility.py @@ -1,6 +1,6 @@ from functools import reduce from operator import mul -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import Any, Optional, Sequence import gym import numpy as np @@ -65,14 +65,14 @@ def __init__( def step( self, action: NDArray - ) -> Tuple[NDArray, float, bool, bool, Dict[str, Any]]: + ) -> tuple[NDArray, float, bool, bool, dict[str, Any]]: self.t += 1 observation = self.observations[self.episode - 1, self.t] reward = np.mean(observation) + np.mean(action) done = self.t == self.episode_length return observation, float(reward), done, False, {} - def reset(self, **kwargs: Any) -> Tuple[NDArray, Dict[str, Any]]: + def reset(self, **kwargs: Any) -> tuple[NDArray, dict[str, Any]]: self.t = 0 self.episode += 1 return self.observations[self.episode - 1, 0], {} diff --git a/tests/models/torch/q_functions/test_ensemble_q_function.py b/tests/models/torch/q_functions/test_ensemble_q_function.py index 8be8ef13..1c1d1b60 100644 --- a/tests/models/torch/q_functions/test_ensemble_q_function.py +++ b/tests/models/torch/q_functions/test_ensemble_q_function.py @@ -1,4 +1,3 @@ -from typing import List import pytest import torch @@ -94,7 +93,7 @@ def test_discrete_ensemble_q_function_forwarder( n_quantiles: int, embed_size: int, ) -> None: - forwarders: List[DiscreteQFunctionForwarder] = [] + forwarders: list[DiscreteQFunctionForwarder] = [] for _ in range(ensemble_size): encoder = DummyEncoder(observation_shape) forwarder: DiscreteQFunctionForwarder @@ -214,7 +213,7 @@ def test_ensemble_continuous_q_function( n_quantiles: int, embed_size: int, ) -> None: - forwarders: List[ContinuousQFunctionForwarder] = [] + forwarders: list[ContinuousQFunctionForwarder] = [] for _ in range(ensemble_size): forwarder: ContinuousQFunctionForwarder encoder = DummyEncoderWithAction(observation_shape, action_size) diff --git a/tests/models/torch/test_encoders.py b/tests/models/torch/test_encoders.py index 3e7df8c1..92c9af9e 100644 --- a/tests/models/torch/test_encoders.py +++ b/tests/models/torch/test_encoders.py @@ -1,5 +1,5 @@ # pylint: disable=protected-access -from typing import List, Optional, Sequence, Tuple +from typing import Optional, Sequence import pytest import torch @@ -23,8 +23,8 @@ @pytest.mark.parametrize("activation", [torch.nn.ReLU()]) @pytest.mark.parametrize("last_activation", [None, torch.nn.ReLU()]) def test_pixel_encoder( - shapes: Tuple[Sequence[int], int], - filters: List[List[int]], + shapes: tuple[Sequence[int], int], + filters: list[list[int]], feature_size: int, batch_size: int, use_batch_norm: bool, @@ -72,9 +72,9 @@ def test_pixel_encoder( @pytest.mark.parametrize("activation", [torch.nn.ReLU()]) @pytest.mark.parametrize("last_activation", [None, torch.nn.ReLU()]) def test_pixel_encoder_with_action( - shapes: Tuple[Sequence[int], int], + shapes: tuple[Sequence[int], int], action_size: int, - filters: List[List[int]], + filters: list[list[int]], feature_size: int, batch_size: int, use_batch_norm: bool, diff --git a/tests/preprocessing/test_base.py b/tests/preprocessing/test_base.py index 2e5438de..0a79e97c 100644 --- a/tests/preprocessing/test_base.py +++ b/tests/preprocessing/test_base.py @@ -13,4 +13,4 @@ def test_add_leading_dims() -> None: def test_add_leading_dims_numpy() -> None: x = np.random.random(3) target = np.random.random((1, 2, 3)) - assert add_leading_dims_numpy(x, target).shape == (1, 1, 3) # type: ignore + assert add_leading_dims_numpy(x, target).shape == (1, 1, 3) diff --git a/tests/test_torch_utility.py b/tests/test_torch_utility.py index a61113c6..d7715813 100644 --- a/tests/test_torch_utility.py +++ b/tests/test_torch_utility.py @@ -1,7 +1,7 @@ import copy import dataclasses from io import BytesIO -from typing import Any, Dict, Optional, Sequence +from typing import Any, Optional, Sequence from unittest.mock import Mock import numpy as np @@ -94,8 +94,8 @@ def test_sync_optimizer_state(input_size: int, output_size: int) -> None: # check if state is synced targ_state = targ_optim.state_dict()["state"] state = optim.state_dict()["state"] - for i, l in targ_state.items(): - for k, v in l.items(): + for i, v_dict in targ_state.items(): + for k, v in v_dict.items(): if isinstance(v, int): assert v == state[i][k] else: @@ -136,7 +136,7 @@ def eval_api_func(self) -> None: assert self.fc2.training -def check_if_same_dict(a: Dict[str, Any], b: Dict[str, Any]) -> None: +def check_if_same_dict(a: dict[str, Any], b: dict[str, Any]) -> None: for k, v in a.items(): if isinstance(v, torch.Tensor): assert (b[k] == v).all() diff --git a/tests/testing_utils.py b/tests/testing_utils.py index f87b09f3..c5589b70 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, Tuple, cast, overload +from typing import Optional, Sequence, cast, overload import numpy as np import torch @@ -262,7 +262,7 @@ def create_partial_trajectory( def create_scaler_tuple( name: Optional[str], observation_shape: Shape, -) -> Tuple[ +) -> tuple[ Optional[ObservationScaler], Optional[ActionScaler], Optional[RewardScaler] ]: if name is None: