From 8cb7c4d596333f15de4677824baabf873f125382 Mon Sep 17 00:00:00 2001 From: takuseno Date: Sun, 18 Feb 2024 10:54:46 +0900 Subject: [PATCH] Apply upgraded black format --- d3rlpy/algos/qlearning/awac.py | 1 + d3rlpy/algos/qlearning/bc.py | 2 + d3rlpy/algos/qlearning/bcq.py | 2 + d3rlpy/algos/qlearning/bear.py | 1 + d3rlpy/algos/qlearning/cql.py | 2 + d3rlpy/algos/qlearning/crr.py | 1 + d3rlpy/algos/qlearning/ddpg.py | 1 + d3rlpy/algos/qlearning/dqn.py | 2 + d3rlpy/algos/qlearning/iql.py | 1 + d3rlpy/algos/qlearning/nfq.py | 1 + d3rlpy/algos/qlearning/plas.py | 2 + d3rlpy/algos/qlearning/random_policy.py | 1 + d3rlpy/algos/qlearning/sac.py | 2 + d3rlpy/algos/qlearning/td3.py | 1 + d3rlpy/algos/qlearning/td3_plus_bc.py | 1 + d3rlpy/algos/transformer/action_samplers.py | 1 + d3rlpy/algos/transformer/base.py | 5 +- d3rlpy/base.py | 6 +- d3rlpy/dataset/buffers.py | 2 + d3rlpy/dataset/components.py | 5 ++ d3rlpy/dataset/episode_generator.py | 1 + d3rlpy/dataset/mini_batch.py | 2 + d3rlpy/dataset/replay_buffer.py | 2 + d3rlpy/dataset/trajectory_slicers.py | 1 + d3rlpy/dataset/transition_pickers.py | 2 + d3rlpy/dataset/utils.py | 71 +++++++------------ d3rlpy/envs/wrappers.py | 22 +++--- d3rlpy/interface.py | 30 +++----- d3rlpy/logging/file_adapter.py | 2 + d3rlpy/logging/logger.py | 3 +- d3rlpy/logging/tensorboard_adapter.py | 2 + d3rlpy/logging/utils.py | 1 + d3rlpy/metrics/evaluators.py | 9 +++ d3rlpy/models/encoders.py | 32 +++++---- .../torch/q_functions/ensemble_q_function.py | 16 +++-- d3rlpy/ope/fqe.py | 1 + d3rlpy/preprocessing/action_scalers.py | 1 + d3rlpy/preprocessing/observation_scalers.py | 8 ++- d3rlpy/preprocessing/reward_scalers.py | 6 ++ d3rlpy/tokenizers/tokenizers.py | 6 +- d3rlpy/torch_utility.py | 8 +-- tests/testing_utils.py | 30 +++----- 42 files changed, 165 insertions(+), 131 deletions(-) diff --git a/d3rlpy/algos/qlearning/awac.py b/d3rlpy/algos/qlearning/awac.py index d8414e6e..7e1ae56c 100644 --- a/d3rlpy/algos/qlearning/awac.py +++ b/d3rlpy/algos/qlearning/awac.py @@ -71,6 +71,7 @@ class AWACConfig(LearnableConfig): :math:`A^\pi(s_t, a_t)`. n_critics (int): Number of Q functions for ensemble. """ + actor_learning_rate: float = 3e-4 critic_learning_rate: float = 3e-4 actor_optim_factory: OptimizerFactory = make_optimizer_field() diff --git a/d3rlpy/algos/qlearning/bc.py b/d3rlpy/algos/qlearning/bc.py index 507cf9e2..c614e80d 100644 --- a/d3rlpy/algos/qlearning/bc.py +++ b/d3rlpy/algos/qlearning/bc.py @@ -50,6 +50,7 @@ class BCConfig(LearnableConfig): Observation preprocessor. action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor. """ + batch_size: int = 100 learning_rate: float = 1e-3 policy_type: str = "deterministic" @@ -133,6 +134,7 @@ class DiscreteBCConfig(LearnableConfig): observation_scaler (d3rlpy.preprocessing.ObservationScaler): Observation preprocessor. """ + batch_size: int = 100 learning_rate: float = 1e-3 optim_factory: OptimizerFactory = make_optimizer_field() diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py index a66b3d43..d905f7a0 100644 --- a/d3rlpy/algos/qlearning/bcq.py +++ b/d3rlpy/algos/qlearning/bcq.py @@ -138,6 +138,7 @@ class BCQConfig(LearnableConfig): functions. If this is large, RL training would be more stabilized. beta (float): KL reguralization term for Conditional VAE. 
""" + actor_learning_rate: float = 1e-3 critic_learning_rate: float = 1e-3 imitator_learning_rate: float = 1e-3 @@ -323,6 +324,7 @@ class DiscreteBCQConfig(LearnableConfig): share_encoder (bool): Flag to share encoder between Q-function and imitation models. """ + learning_rate: float = 6.25e-5 optim_factory: OptimizerFactory = make_optimizer_field() encoder_factory: EncoderFactory = make_encoder_field() diff --git a/d3rlpy/algos/qlearning/bear.py b/d3rlpy/algos/qlearning/bear.py index b02bf35c..d6dfc7be 100644 --- a/d3rlpy/algos/qlearning/bear.py +++ b/d3rlpy/algos/qlearning/bear.py @@ -115,6 +115,7 @@ class BEARConfig(LearnableConfig): warmup_steps (int): Number of steps to warmup the policy function. """ + actor_learning_rate: float = 1e-4 critic_learning_rate: float = 3e-4 imitator_learning_rate: float = 3e-4 diff --git a/d3rlpy/algos/qlearning/cql.py b/d3rlpy/algos/qlearning/cql.py index 3650f8f8..d03fc7cd 100644 --- a/d3rlpy/algos/qlearning/cql.py +++ b/d3rlpy/algos/qlearning/cql.py @@ -100,6 +100,7 @@ class CQLConfig(LearnableConfig): :math:`\log{\sum_a \exp{Q(s, a)}}`. soft_q_backup (bool): Flag to use SAC-style backup. """ + actor_learning_rate: float = 1e-4 critic_learning_rate: float = 3e-4 temp_learning_rate: float = 1e-4 @@ -256,6 +257,7 @@ class DiscreteCQLConfig(LearnableConfig): network. alpha (float): math:`\alpha` value above. """ + learning_rate: float = 6.25e-5 optim_factory: OptimizerFactory = make_optimizer_field() encoder_factory: EncoderFactory = make_encoder_field() diff --git a/d3rlpy/algos/qlearning/crr.py b/d3rlpy/algos/qlearning/crr.py index d690a994..fd29cc56 100644 --- a/d3rlpy/algos/qlearning/crr.py +++ b/d3rlpy/algos/qlearning/crr.py @@ -100,6 +100,7 @@ class CRRConfig(LearnableConfig): update_actor_interval (int): Interval to update policy function used with ``hard`` target update. """ + actor_learning_rate: float = 3e-4 critic_learning_rate: float = 3e-4 actor_optim_factory: OptimizerFactory = make_optimizer_field() diff --git a/d3rlpy/algos/qlearning/ddpg.py b/d3rlpy/algos/qlearning/ddpg.py index 5e83f78e..e2dc6d1e 100644 --- a/d3rlpy/algos/qlearning/ddpg.py +++ b/d3rlpy/algos/qlearning/ddpg.py @@ -70,6 +70,7 @@ class DDPGConfig(LearnableConfig): tau (float): Target network synchronization coefficiency. n_critics (int): Number of Q functions for ensemble. """ + batch_size: int = 256 actor_learning_rate: float = 3e-4 critic_learning_rate: float = 3e-4 diff --git a/d3rlpy/algos/qlearning/dqn.py b/d3rlpy/algos/qlearning/dqn.py index ed21b1d0..ff729d5a 100644 --- a/d3rlpy/algos/qlearning/dqn.py +++ b/d3rlpy/algos/qlearning/dqn.py @@ -45,6 +45,7 @@ class DQNConfig(LearnableConfig): n_critics (int): Number of Q functions for ensemble. target_update_interval (int): Interval to update the target network. """ + batch_size: int = 32 learning_rate: float = 6.25e-5 optim_factory: OptimizerFactory = make_optimizer_field() @@ -147,6 +148,7 @@ class DoubleDQNConfig(DQNConfig): target_update_interval (int): Interval to synchronize the target network. """ + batch_size: int = 32 learning_rate: float = 6.25e-5 optim_factory: OptimizerFactory = make_optimizer_field() diff --git a/d3rlpy/algos/qlearning/iql.py b/d3rlpy/algos/qlearning/iql.py index a95863cb..4f1ce04c 100644 --- a/d3rlpy/algos/qlearning/iql.py +++ b/d3rlpy/algos/qlearning/iql.py @@ -81,6 +81,7 @@ class IQLConfig(LearnableConfig): :math:`\beta`. max_weight (float): Maximum advantage weight value to clip. 
""" + actor_learning_rate: float = 3e-4 critic_learning_rate: float = 3e-4 actor_optim_factory: OptimizerFactory = make_optimizer_field() diff --git a/d3rlpy/algos/qlearning/nfq.py b/d3rlpy/algos/qlearning/nfq.py index 46e889ac..245b473c 100644 --- a/d3rlpy/algos/qlearning/nfq.py +++ b/d3rlpy/algos/qlearning/nfq.py @@ -48,6 +48,7 @@ class NFQConfig(LearnableConfig): gamma (float): Discount factor. n_critics (int): Number of Q functions for ensemble. """ + learning_rate: float = 6.25e-5 optim_factory: OptimizerFactory = make_optimizer_field() encoder_factory: EncoderFactory = make_encoder_field() diff --git a/d3rlpy/algos/qlearning/plas.py b/d3rlpy/algos/qlearning/plas.py index 0db66d53..05b5d625 100644 --- a/d3rlpy/algos/qlearning/plas.py +++ b/d3rlpy/algos/qlearning/plas.py @@ -78,6 +78,7 @@ class PLASConfig(LearnableConfig): warmup_steps (int): Number of steps to warmup the VAE. beta (float): KL reguralization term for Conditional VAE. """ + actor_learning_rate: float = 1e-4 critic_learning_rate: float = 1e-3 imitator_learning_rate: float = 1e-4 @@ -239,6 +240,7 @@ class PLASWithPerturbationConfig(PLASConfig): warmup_steps (int): Number of steps to warmup the VAE. beta (float): KL reguralization term for Conditional VAE. """ + action_flexibility: float = 0.05 def create(self, device: DeviceArg = False) -> "PLASWithPerturbation": diff --git a/d3rlpy/algos/qlearning/random_policy.py b/d3rlpy/algos/qlearning/random_policy.py index 2f3a3c4d..8190c0da 100644 --- a/d3rlpy/algos/qlearning/random_policy.py +++ b/d3rlpy/algos/qlearning/random_policy.py @@ -31,6 +31,7 @@ class RandomPolicyConfig(LearnableConfig): normal_std (float): Standard deviation of the normal distribution. This is only used when ``distribution='normal'``. """ + distribution: str = "uniform" normal_std: float = 1.0 diff --git a/d3rlpy/algos/qlearning/sac.py b/d3rlpy/algos/qlearning/sac.py index 20a9cec3..c31cb69e 100644 --- a/d3rlpy/algos/qlearning/sac.py +++ b/d3rlpy/algos/qlearning/sac.py @@ -95,6 +95,7 @@ class SACConfig(LearnableConfig): n_critics (int): Number of Q functions for ensemble. initial_temperature (float): Initial temperature value. """ + actor_learning_rate: float = 3e-4 critic_learning_rate: float = 3e-4 temp_learning_rate: float = 3e-4 @@ -243,6 +244,7 @@ class DiscreteSACConfig(LearnableConfig): n_critics (int): Number of Q functions for ensemble. initial_temperature (float): Initial temperature value. """ + actor_learning_rate: float = 3e-4 critic_learning_rate: float = 3e-4 temp_learning_rate: float = 3e-4 diff --git a/d3rlpy/algos/qlearning/td3.py b/d3rlpy/algos/qlearning/td3.py index 2633bce0..58ba5757 100644 --- a/d3rlpy/algos/qlearning/td3.py +++ b/d3rlpy/algos/qlearning/td3.py @@ -75,6 +75,7 @@ class TD3Config(LearnableConfig): update_actor_interval (int): Interval to update policy function described as `delayed policy update` in the paper. """ + actor_learning_rate: float = 3e-4 critic_learning_rate: float = 3e-4 actor_optim_factory: OptimizerFactory = make_optimizer_field() diff --git a/d3rlpy/algos/qlearning/td3_plus_bc.py b/d3rlpy/algos/qlearning/td3_plus_bc.py index ae1940fd..55d08c25 100644 --- a/d3rlpy/algos/qlearning/td3_plus_bc.py +++ b/d3rlpy/algos/qlearning/td3_plus_bc.py @@ -66,6 +66,7 @@ class TD3PlusBCConfig(LearnableConfig): update_actor_interval (int): Interval to update policy function described as `delayed policy update` in the paper. 
""" + actor_learning_rate: float = 3e-4 critic_learning_rate: float = 3e-4 actor_optim_factory: OptimizerFactory = make_optimizer_field() diff --git a/d3rlpy/algos/transformer/action_samplers.py b/d3rlpy/algos/transformer/action_samplers.py index 180543c0..43cd8ce6 100644 --- a/d3rlpy/algos/transformer/action_samplers.py +++ b/d3rlpy/algos/transformer/action_samplers.py @@ -48,6 +48,7 @@ class SoftmaxTransformerActionSampler(TransformerActionSampler): Args: temperature (int): Softmax temperature. """ + _temperature: float def __init__(self, temperature: float = 1.0): diff --git a/d3rlpy/algos/transformer/base.py b/d3rlpy/algos/transformer/base.py index d26d6e81..3879a268 100644 --- a/d3rlpy/algos/transformer/base.py +++ b/d3rlpy/algos/transformer/base.py @@ -46,7 +46,7 @@ def predict(self, inpt: TorchTransformerInput) -> torch.Tensor: @abstractmethod def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor: - ... + raise NotImplementedError @train_api def update( @@ -58,7 +58,7 @@ def update( def inner_update( self, batch: TorchTrajectoryMiniBatch, grad_step: int ) -> Dict[str, float]: - pass + raise NotImplementedError @dataclasses.dataclass() @@ -101,6 +101,7 @@ class StatefulTransformerWrapper(Generic[TTransformerImpl, TTransformerConfig]): target_return (float): Target return. action_sampler (d3rlpy.algos.TransformerActionSampler): Action sampler. """ + _algo: "TransformerAlgoBase[TTransformerImpl, TTransformerConfig]" _target_return: float _action_sampler: TransformerActionSampler diff --git a/d3rlpy/base.py b/d3rlpy/base.py index caad56d8..bed3f38d 100644 --- a/d3rlpy/base.py +++ b/d3rlpy/base.py @@ -97,9 +97,9 @@ def unwrap_models_by_ddp(self) -> None: class LearnableConfig(DynamicConfig): batch_size: int = 256 gamma: float = 0.99 - observation_scaler: Optional[ - ObservationScaler - ] = make_observation_scaler_field() + observation_scaler: Optional[ObservationScaler] = ( + make_observation_scaler_field() + ) action_scaler: Optional[ActionScaler] = make_action_scaler_field() reward_scaler: Optional[RewardScaler] = make_reward_scaler_field() diff --git a/d3rlpy/dataset/buffers.py b/d3rlpy/dataset/buffers.py index 9e91ab19..d2cde6e1 100644 --- a/d3rlpy/dataset/buffers.py +++ b/d3rlpy/dataset/buffers.py @@ -44,6 +44,7 @@ def __getitem__(self, index: int) -> Tuple[EpisodeBase, int]: class InfiniteBuffer(BufferProtocol): r"""Buffer with unlimited capacity.""" + _transitions: List[Tuple[EpisodeBase, int]] _episodes: List[EpisodeBase] @@ -78,6 +79,7 @@ class FIFOBuffer(BufferProtocol): Args: limit (int): buffer capacity. """ + _transitions: Deque[Tuple[EpisodeBase, int]] _episodes: List[EpisodeBase] _limit: int diff --git a/d3rlpy/dataset/components.py b/d3rlpy/dataset/components.py index 4d1096c9..7429c5c0 100644 --- a/d3rlpy/dataset/components.py +++ b/d3rlpy/dataset/components.py @@ -40,6 +40,7 @@ class Signature: dtype: List of numpy data types. shape: List of array shapes. """ + dtype: Sequence[DType] shape: Sequence[Sequence[int]] @@ -69,6 +70,7 @@ class Transition: terminal: Flag of environment termination. interval: Timesteps between ``observation`` and ``next_observation``. """ + observation: Observation # (...) action: NDArray # (...) reward: Float32NDArray # (1,) @@ -130,6 +132,7 @@ class PartialTrajectory: masks: Sequence of masks that represent padding. length: Sequence length. """ + observations: ObservationSequence # (L, ...) actions: NDArray # (L, ...) rewards: Float32NDArray # (L, 1) @@ -350,6 +353,7 @@ class Episode: rewards: Sequence of rewards. 
terminated: Flag of environment termination. """ + observations: ObservationSequence actions: NDArray rewards: Float32NDArray @@ -422,6 +426,7 @@ class DatasetInfo: this represents dimension of action vectors. For discrete action-space, this represents the number of discrete actions. """ + observation_signature: Signature action_signature: Signature reward_signature: Signature diff --git a/d3rlpy/dataset/episode_generator.py b/d3rlpy/dataset/episode_generator.py index ab230126..af521be0 100644 --- a/d3rlpy/dataset/episode_generator.py +++ b/d3rlpy/dataset/episode_generator.py @@ -32,6 +32,7 @@ class EpisodeGenerator(EpisodeGeneratorProtocol): terminals: Sequence of environment terminal flags. timeouts: Sequence of timeout flags. """ + _observations: ObservationSequence _actions: NDArray _rewards: Float32NDArray diff --git a/d3rlpy/dataset/mini_batch.py b/d3rlpy/dataset/mini_batch.py index 965281b4..fa09fd82 100644 --- a/d3rlpy/dataset/mini_batch.py +++ b/d3rlpy/dataset/mini_batch.py @@ -30,6 +30,7 @@ class TransitionMiniBatch: intervals: Batched timesteps between observations and next observations. """ + observations: Union[Float32NDArray, Sequence[Float32NDArray]] # (B, ...) actions: Float32NDArray # (B, ...) rewards: Float32NDArray # (B, 1) @@ -146,6 +147,7 @@ class TrajectoryMiniBatch: masks: Batched masks that represent padding. length: Length of trajectories. """ + observations: Union[Float32NDArray, Sequence[Float32NDArray]] # (B, L, ...) actions: Float32NDArray # (B, L, ...) rewards: Float32NDArray # (B, L, 1) diff --git a/d3rlpy/dataset/replay_buffer.py b/d3rlpy/dataset/replay_buffer.py index 8418c1b6..f6f9056e 100644 --- a/d3rlpy/dataset/replay_buffer.py +++ b/d3rlpy/dataset/replay_buffer.py @@ -329,6 +329,7 @@ class ReplayBuffer(ReplayBufferBase): for online training. ``cache_size`` needs to be greater than the maximum possible episode length. """ + _buffer: BufferProtocol _transition_picker: TransitionPickerProtocol _trajectory_slicer: TrajectorySlicerProtocol @@ -587,6 +588,7 @@ class MixedReplayBuffer(ReplayBufferBase): secondary_mix_ratio (float): Ratio to sample mini-batches from the secondary replay buffer. """ + _primary_replay_buffer: ReplayBufferBase _secondary_replay_buffer: ReplayBufferBase _secondary_mix_ratio: float diff --git a/d3rlpy/dataset/trajectory_slicers.py b/d3rlpy/dataset/trajectory_slicers.py index 0feb1625..8f9bec58 100644 --- a/d3rlpy/dataset/trajectory_slicers.py +++ b/d3rlpy/dataset/trajectory_slicers.py @@ -116,6 +116,7 @@ class FrameStackTrajectorySlicer(TrajectorySlicerProtocol): Args: n_frames: Number of frames to stack. """ + _n_frames: int def __init__(self, n_frames: int): diff --git a/d3rlpy/dataset/transition_pickers.py b/d3rlpy/dataset/transition_pickers.py index 75bec9b0..bb502d46 100644 --- a/d3rlpy/dataset/transition_pickers.py +++ b/d3rlpy/dataset/transition_pickers.py @@ -103,6 +103,7 @@ class FrameStackTransitionPicker(TransitionPickerProtocol): n_frames (int): Number of frames to stack. gamma (float): Discount factor to compute return-to-go. """ + _n_frames: int _gamma: float @@ -152,6 +153,7 @@ class MultiStepTransitionPicker(TransitionPickerProtocol): ``net_observation``. gamma: Discount factor to compute a multi-step return. """ + _n_steps: int _gamma: float diff --git a/d3rlpy/dataset/utils.py b/d3rlpy/dataset/utils.py index fd342a78..403dfc1a 100644 --- a/d3rlpy/dataset/utils.py +++ b/d3rlpy/dataset/utils.py @@ -43,15 +43,13 @@ @overload -def retrieve_observation(observations: NDArray, index: int) -> NDArray: - ... 
+def retrieve_observation(observations: NDArray, index: int) -> NDArray: ... @overload def retrieve_observation( observations: Sequence[NDArray], index: int -) -> Sequence[NDArray]: - ... +) -> Sequence[NDArray]: ... def retrieve_observation( @@ -66,15 +64,13 @@ def retrieve_observation( @overload -def create_zero_observation(observation: NDArray) -> NDArray: - ... +def create_zero_observation(observation: NDArray) -> NDArray: ... @overload def create_zero_observation( observation: Sequence[NDArray], -) -> Sequence[NDArray]: - ... +) -> Sequence[NDArray]: ... def create_zero_observation(observation: Observation) -> Observation: @@ -87,15 +83,15 @@ def create_zero_observation(observation: Observation) -> Observation: @overload -def slice_observations(observations: NDArray, start: int, end: int) -> NDArray: - ... +def slice_observations( + observations: NDArray, start: int, end: int +) -> NDArray: ... @overload def slice_observations( observations: Sequence[NDArray], start: int, end: int -) -> Sequence[NDArray]: - ... +) -> Sequence[NDArray]: ... def slice_observations( @@ -123,15 +119,13 @@ def batch_pad_array( @overload -def batch_pad_observations(observations: NDArray, pad_size: int) -> NDArray: - ... +def batch_pad_observations(observations: NDArray, pad_size: int) -> NDArray: ... @overload def batch_pad_observations( observations: Sequence[NDArray], pad_size: int -) -> Sequence[NDArray]: - ... +) -> Sequence[NDArray]: ... def batch_pad_observations( @@ -151,15 +145,13 @@ def batch_pad_observations( @overload def stack_recent_observations( observations: NDArray, index: int, n_frames: int -) -> NDArray: - ... +) -> NDArray: ... @overload def stack_recent_observations( observations: Sequence[NDArray], index: int, n_frames: int -) -> Sequence[NDArray]: - ... +) -> Sequence[NDArray]: ... def stack_recent_observations( @@ -191,20 +183,17 @@ def squeeze_batch_dim(array: NDArray) -> NDArray: @overload -def stack_observations(observations: Sequence[NDArray]) -> NDArray: - ... +def stack_observations(observations: Sequence[NDArray]) -> NDArray: ... @overload def stack_observations( observations: Sequence[Sequence[NDArray]], -) -> Sequence[NDArray]: - ... +) -> Sequence[NDArray]: ... @overload -def stack_observations(observations: Sequence[Observation]) -> Observation: - ... +def stack_observations(observations: Sequence[Observation]) -> Observation: ... def stack_observations(observations: Sequence[Observation]) -> Observation: @@ -221,15 +210,13 @@ def stack_observations(observations: Sequence[Observation]) -> Observation: @overload -def get_shape_from_observation(observation: NDArray) -> Sequence[int]: - ... +def get_shape_from_observation(observation: NDArray) -> Sequence[int]: ... @overload def get_shape_from_observation( observation: Sequence[NDArray], -) -> Sequence[Sequence[int]]: - ... +) -> Sequence[Sequence[int]]: ... def get_shape_from_observation(observation: Observation) -> Shape: @@ -244,15 +231,13 @@ def get_shape_from_observation(observation: Observation) -> Shape: @overload def get_shape_from_observation_sequence( observations: NDArray, -) -> Sequence[int]: - ... +) -> Sequence[int]: ... @overload def get_shape_from_observation_sequence( observations: Sequence[NDArray], -) -> Sequence[Sequence[int]]: - ... +) -> Sequence[Sequence[int]]: ... def get_shape_from_observation_sequence( @@ -267,15 +252,13 @@ def get_shape_from_observation_sequence( @overload -def get_dtype_from_observation(observation: NDArray) -> DType: - ... 
+def get_dtype_from_observation(observation: NDArray) -> DType: ... @overload def get_dtype_from_observation( observation: Sequence[NDArray], -) -> Sequence[DType]: - ... +) -> Sequence[DType]: ... def get_dtype_from_observation( @@ -292,15 +275,13 @@ def get_dtype_from_observation( @overload def get_dtype_from_observation_sequence( observations: NDArray, -) -> DType: - ... +) -> DType: ... @overload def get_dtype_from_observation_sequence( observations: Sequence[NDArray], -) -> Sequence[DType]: - ... +) -> Sequence[DType]: ... def get_dtype_from_observation_sequence( @@ -335,15 +316,13 @@ def check_non_1d_array(array: Union[NDArray, Sequence[NDArray]]) -> bool: @overload def cast_recursively( array: NDArray, dtype: Type[_TDType] -) -> npt.NDArray[_TDType]: - ... +) -> npt.NDArray[_TDType]: ... @overload def cast_recursively( array: Sequence[NDArray], dtype: Type[_TDType] -) -> Sequence[npt.NDArray[_TDType]]: - ... +) -> Sequence[npt.NDArray[_TDType]]: ... def cast_recursively( diff --git a/d3rlpy/envs/wrappers.py b/d3rlpy/envs/wrappers.py index 905861a8..407811eb 100644 --- a/d3rlpy/envs/wrappers.py +++ b/d3rlpy/envs/wrappers.py @@ -188,6 +188,7 @@ class AtariPreprocessing(gym.Wrapper[NDArray, int]): FrameStack Wrapper. """ + _obs_buffer: Sequence[NDArray] def __init__( @@ -388,6 +389,7 @@ class GoalConcatWrapper( goal_key (str): String key of the goal observation. tuple_observation (bool): Flag to include goals as tuple element. """ + _observation_space: Union[GymnasiumBox, GymnasiumTuple] _observation_key: str _goal_key: str @@ -416,17 +418,21 @@ def __init__( goal_spaces = [goal_space[key] for key in goal_keys] goal_space_low = np.concatenate( [ - [space.low] * space.shape[0] # type: ignore - if np.isscalar(space.low) # type: ignore - else space.low # type: ignore + ( + [space.low] * space.shape[0] # type: ignore + if np.isscalar(space.low) # type: ignore + else space.low # type: ignore + ) for space in goal_spaces ] ) goal_space_high = np.concatenate( [ - [space.high] * space.shape[0] # type: ignore - if np.isscalar(space.high) # type: ignore - else space.high # type: ignore + ( + [space.high] * space.shape[0] # type: ignore + if np.isscalar(space.high) # type: ignore + else space.high # type: ignore + ) for space in goal_spaces ] ) @@ -446,9 +452,7 @@ def __init__( dtype=observation_space.dtype, # type: ignore ) - def step( - self, action: _ActType - ) -> Tuple[ + def step(self, action: _ActType) -> Tuple[ Union[NDArray, Tuple[NDArray, NDArray]], SupportsFloat, bool, diff --git a/d3rlpy/interface.py b/d3rlpy/interface.py index 5525f1d0..6ec9d728 100644 --- a/d3rlpy/interface.py +++ b/d3rlpy/interface.py @@ -9,39 +9,29 @@ class QLearningAlgoProtocol(Protocol): - def predict(self, x: Observation) -> NDArray: - ... + def predict(self, x: Observation) -> NDArray: ... - def predict_value(self, x: Observation, action: NDArray) -> NDArray: - ... + def predict_value(self, x: Observation, action: NDArray) -> NDArray: ... - def sample_action(self, x: Observation) -> NDArray: - ... + def sample_action(self, x: Observation) -> NDArray: ... @property - def gamma(self) -> float: - ... + def gamma(self) -> float: ... @property - def observation_scaler(self) -> Optional[ObservationScaler]: - ... + def observation_scaler(self) -> Optional[ObservationScaler]: ... @property - def action_scaler(self) -> Optional[ActionScaler]: - ... + def action_scaler(self) -> Optional[ActionScaler]: ... @property - def reward_scaler(self) -> Optional[RewardScaler]: - ... 
+ def reward_scaler(self) -> Optional[RewardScaler]: ... @property - def action_size(self) -> Optional[int]: - ... + def action_size(self) -> Optional[int]: ... class StatefulTransformerAlgoProtocol(Protocol): - def predict(self, x: Observation, reward: float) -> Union[NDArray, int]: - ... + def predict(self, x: Observation, reward: float) -> Union[NDArray, int]: ... - def reset(self) -> None: - ... + def reset(self) -> None: ... diff --git a/d3rlpy/logging/file_adapter.py b/d3rlpy/logging/file_adapter.py index 8b9f34e0..21439294 100644 --- a/d3rlpy/logging/file_adapter.py +++ b/d3rlpy/logging/file_adapter.py @@ -32,6 +32,7 @@ class FileAdapter(LoggerAdapter): Args: logdir (str): Log directory. """ + _logdir: str def __init__(self, logdir: str): @@ -85,6 +86,7 @@ class FileAdapterFactory(LoggerAdapterFactory): Args: root_dir (str): Top-level log directory. """ + _root_dir: str def __init__(self, root_dir: str = "d3rlpy_logs"): diff --git a/d3rlpy/logging/logger.py b/d3rlpy/logging/logger.py index cd7d11dc..48b8cd0d 100644 --- a/d3rlpy/logging/logger.py +++ b/d3rlpy/logging/logger.py @@ -36,8 +36,7 @@ def set_log_context(**kwargs: Any) -> None: class SaveProtocol(Protocol): - def save(self, fname: str) -> None: - ... + def save(self, fname: str) -> None: ... class LoggerAdapter(Protocol): diff --git a/d3rlpy/logging/tensorboard_adapter.py b/d3rlpy/logging/tensorboard_adapter.py index ea8caf7f..35fe5632 100644 --- a/d3rlpy/logging/tensorboard_adapter.py +++ b/d3rlpy/logging/tensorboard_adapter.py @@ -21,6 +21,7 @@ class TensorboardAdapter(LoggerAdapter): root_dir (str): Top-level log directory. experiment_name (str): Experiment name. """ + _experiment_name: str _params: Dict[str, Any] _metrics: Dict[str, float] @@ -73,6 +74,7 @@ class TensorboardAdapterFactory(LoggerAdapterFactory): Args: root_dir (str): Top-level log directory. """ + _root_dir: str def __init__(self, root_dir: str = "tensorboard_logs"): diff --git a/d3rlpy/logging/utils.py b/d3rlpy/logging/utils.py index ff494207..48bb77e1 100644 --- a/d3rlpy/logging/utils.py +++ b/d3rlpy/logging/utils.py @@ -54,6 +54,7 @@ class CombineAdapterFactory(LoggerAdapterFactory): adapter_factories (Sequence[LoggerAdapterFactory]): List of LoggerAdapterFactory. """ + _adapter_factories: Sequence[LoggerAdapterFactory] def __init__(self, adapter_factories: Sequence[LoggerAdapterFactory]): diff --git a/d3rlpy/metrics/evaluators.py b/d3rlpy/metrics/evaluators.py index 05b1c095..15a267bc 100644 --- a/d3rlpy/metrics/evaluators.py +++ b/d3rlpy/metrics/evaluators.py @@ -85,6 +85,7 @@ class TDErrorEvaluator(EvaluatorProtocol): episodes: Optional evaluation episodes. If it's not given, dataset used in training will be used. """ + _episodes: Optional[Sequence[EpisodeBase]] def __init__(self, episodes: Optional[Sequence[EpisodeBase]] = None): @@ -145,6 +146,7 @@ class DiscountedSumOfAdvantageEvaluator(EvaluatorProtocol): episodes: Optional evaluation episodes. If it's not given, dataset used in training will be used. """ + _episodes: Optional[Sequence[EpisodeBase]] def __init__(self, episodes: Optional[Sequence[EpisodeBase]] = None): @@ -202,6 +204,7 @@ class AverageValueEstimationEvaluator(EvaluatorProtocol): episodes: Optional evaluation episodes. If it's not given, dataset used in training will be used. """ + _episodes: Optional[Sequence[EpisodeBase]] def __init__(self, episodes: Optional[Sequence[EpisodeBase]] = None): @@ -293,6 +296,7 @@ class SoftOPCEvaluator(EvaluatorProtocol): episodes: Optional evaluation episodes. 
If it's not given, dataset used in training will be used. """ + _return_threshold: float _episodes: Optional[Sequence[EpisodeBase]] @@ -340,6 +344,7 @@ class ContinuousActionDiffEvaluator(EvaluatorProtocol): episodes: Optional evaluation episodes. If it's not given, dataset used in training will be used. """ + _episodes: Optional[Sequence[EpisodeBase]] def __init__(self, episodes: Optional[Sequence[EpisodeBase]] = None): @@ -379,6 +384,7 @@ class DiscreteActionMatchEvaluator(EvaluatorProtocol): episodes: Optional evaluation episodes. If it's not given, dataset used in training will be used. """ + _episodes: Optional[Sequence[EpisodeBase]] def __init__(self, episodes: Optional[Sequence[EpisodeBase]] = None): @@ -419,6 +425,7 @@ class CompareContinuousActionDiffEvaluator(EvaluatorProtocol): episodes: Optional evaluation episodes. If it's not given, dataset used in training will be used. """ + _base_algo: QLearningAlgoProtocol _episodes: Optional[Sequence[EpisodeBase]] @@ -468,6 +475,7 @@ class CompareDiscreteActionMatchEvaluator(EvaluatorProtocol): episodes: Optional evaluation episodes. If it's not given, dataset used in training will be used. """ + _base_algo: QLearningAlgoProtocol _episodes: Optional[Sequence[EpisodeBase]] @@ -515,6 +523,7 @@ class EnvironmentEvaluator(EvaluatorProtocol): n_trials: Number of episodes to evaluate. epsilon: Probability of random action. """ + _env: GymEnv _n_trials: int _epsilon: float diff --git a/d3rlpy/models/encoders.py b/d3rlpy/models/encoders.py index fe642f95..315406c3 100644 --- a/d3rlpy/models/encoders.py +++ b/d3rlpy/models/encoders.py @@ -95,9 +95,11 @@ def create(self, observation_shape: Shape) -> PixelEncoder: dropout_rate=self.dropout_rate, activation=create_activation(self.activation), exclude_last_activation=self.exclude_last_activation, - last_activation=create_activation(self.last_activation) - if self.last_activation - else None, + last_activation=( + create_activation(self.last_activation) + if self.last_activation + else None + ), ) def create_with_action( @@ -117,9 +119,11 @@ def create_with_action( discrete_action=discrete_action, activation=create_activation(self.activation), exclude_last_activation=self.exclude_last_activation, - last_activation=create_activation(self.last_activation) - if self.last_activation - else None, + last_activation=( + create_activation(self.last_activation) + if self.last_activation + else None + ), ) @staticmethod @@ -160,9 +164,11 @@ def create(self, observation_shape: Shape) -> VectorEncoder: dropout_rate=self.dropout_rate, activation=create_activation(self.activation), exclude_last_activation=self.exclude_last_activation, - last_activation=create_activation(self.last_activation) - if self.last_activation - else None, + last_activation=( + create_activation(self.last_activation) + if self.last_activation + else None + ), ) def create_with_action( @@ -181,9 +187,11 @@ def create_with_action( discrete_action=discrete_action, activation=create_activation(self.activation), exclude_last_activation=self.exclude_last_activation, - last_activation=create_activation(self.last_activation) - if self.last_activation - else None, + last_activation=( + create_activation(self.last_activation) + if self.last_activation + else None + ), ) @staticmethod diff --git a/d3rlpy/models/torch/q_functions/ensemble_q_function.py b/d3rlpy/models/torch/q_functions/ensemble_q_function.py index 9e27b32b..9b444265 100644 --- a/d3rlpy/models/torch/q_functions/ensemble_q_function.py +++ 
b/d3rlpy/models/torch/q_functions/ensemble_q_function.py @@ -163,9 +163,11 @@ def compute_expected_q( values.append( value.view( 1, - x[0].shape[0] - if isinstance(x, (list, tuple)) - else x.shape[0], # type: ignore + ( + x[0].shape[0] + if isinstance(x, (list, tuple)) + else x.shape[0] # type: ignore + ), self._action_size, ) ) @@ -232,9 +234,11 @@ def compute_expected_q( values.append( value.view( 1, - x[0].shape[0] - if isinstance(x, (list, tuple)) - else x.shape[0], # type: ignore + ( + x[0].shape[0] + if isinstance(x, (list, tuple)) + else x.shape[0] # type: ignore + ), 1, ) ) diff --git a/d3rlpy/ope/fqe.py b/d3rlpy/ope/fqe.py index b7b8c76f..c4d553d9 100644 --- a/d3rlpy/ope/fqe.py +++ b/d3rlpy/ope/fqe.py @@ -61,6 +61,7 @@ class FQEConfig(LearnableConfig): action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor. reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor. """ + learning_rate: float = 1e-4 optim_factory: OptimizerFactory = make_optimizer_field() encoder_factory: EncoderFactory = make_encoder_field() diff --git a/d3rlpy/preprocessing/action_scalers.py b/d3rlpy/preprocessing/action_scalers.py index 9c100f18..56d4f184 100644 --- a/d3rlpy/preprocessing/action_scalers.py +++ b/d3rlpy/preprocessing/action_scalers.py @@ -58,6 +58,7 @@ class MinMaxActionScaler(ActionScaler): minimum (numpy.ndarray): Minimum values at each entry. maximum (numpy.ndarray): Maximum values at each entry. """ + minimum: Optional[NDArray] = make_optional_numpy_field() maximum: Optional[NDArray] = make_optional_numpy_field() diff --git a/d3rlpy/preprocessing/observation_scalers.py b/d3rlpy/preprocessing/observation_scalers.py index aae7ef58..335e1b4b 100644 --- a/d3rlpy/preprocessing/observation_scalers.py +++ b/d3rlpy/preprocessing/observation_scalers.py @@ -119,6 +119,7 @@ class MinMaxObservationScaler(ObservationScaler): minimum (numpy.ndarray): Minimum values at each entry. maximum (numpy.ndarray): Maximum values at each entry. """ + minimum: Optional[NDArray] = make_optional_numpy_field() maximum: Optional[NDArray] = make_optional_numpy_field() @@ -265,6 +266,7 @@ class StandardObservationScaler(ObservationScaler): std (numpy.ndarray): Standard deviation at each entry. eps (float): Small constant value to avoid zero-division. """ + mean: Optional[NDArray] = make_optional_numpy_field() std: Optional[NDArray] = make_optional_numpy_field() eps: float = 1e-3 @@ -411,9 +413,9 @@ class TupleObservationScaler(ObservationScaler): List of observation scalers. """ - observation_scalers: Sequence[ - ObservationScaler - ] = observation_scaler_list_field() + observation_scalers: Sequence[ObservationScaler] = ( + observation_scaler_list_field() + ) def fit_with_transition_picker( self, diff --git a/d3rlpy/preprocessing/reward_scalers.py b/d3rlpy/preprocessing/reward_scalers.py index 9448397d..e3f752e3 100644 --- a/d3rlpy/preprocessing/reward_scalers.py +++ b/d3rlpy/preprocessing/reward_scalers.py @@ -49,6 +49,7 @@ class MultiplyRewardScaler(RewardScaler): Args: multiplier (float): Constant multiplication value. """ + multiplier: float = 1.0 def fit_with_transition_picker( @@ -104,6 +105,7 @@ class ClipRewardScaler(RewardScaler): high (Optional[float]): Maximum value to clip. multiplier (float): Constant multiplication value. """ + low: Optional[float] = None high: Optional[float] = None multiplier: float = 1.0 @@ -170,6 +172,7 @@ class MinMaxRewardScaler(RewardScaler): maximum (float): Maximum value. multiplier (float): Constant multiplication value. 
""" + minimum: Optional[float] = None maximum: Optional[float] = None multiplier: float = 1.0 @@ -262,6 +265,7 @@ class StandardRewardScaler(RewardScaler): eps (float): Constant value to avoid zero-division. multiplier (float): Constant multiplication value """ + mean: Optional[float] = None std: Optional[float] = None eps: float = 1e-3 @@ -359,6 +363,7 @@ class ReturnBasedRewardScaler(RewardScaler): return_min (float): Standard deviation value. multiplier (float): Constant multiplication value """ + return_max: Optional[float] = None return_min: Optional[float] = None multiplier: float = 1.0 @@ -448,6 +453,7 @@ class ConstantShiftRewardScaler(RewardScaler): Args: shift (float): Constant shift value """ + shift: float def fit_with_transition_picker( diff --git a/d3rlpy/tokenizers/tokenizers.py b/d3rlpy/tokenizers/tokenizers.py index 3a05ea6b..36c068e3 100644 --- a/d3rlpy/tokenizers/tokenizers.py +++ b/d3rlpy/tokenizers/tokenizers.py @@ -12,11 +12,9 @@ @runtime_checkable class Tokenizer(Protocol): - def __call__(self, x: NDArray) -> NDArray: - ... + def __call__(self, x: NDArray) -> NDArray: ... - def decode(self, y: Int32NDArray) -> NDArray: - ... + def decode(self, y: Int32NDArray) -> NDArray: ... class FloatTokenizer(Tokenizer): diff --git a/d3rlpy/torch_utility.py b/d3rlpy/torch_utility.py index 0557943e..ccb9ee70 100644 --- a/d3rlpy/torch_utility.py +++ b/d3rlpy/torch_utility.py @@ -87,15 +87,15 @@ def convert_to_torch(array: NDArray, device: str) -> torch.Tensor: @overload -def convert_to_torch_recursively(array: NDArray, device: str) -> torch.Tensor: - ... +def convert_to_torch_recursively( + array: NDArray, device: str +) -> torch.Tensor: ... @overload def convert_to_torch_recursively( array: Sequence[NDArray], device: str -) -> Sequence[torch.Tensor]: - ... +) -> Sequence[torch.Tensor]: ... def convert_to_torch_recursively( diff --git a/tests/testing_utils.py b/tests/testing_utils.py index fa49a12e..646652c1 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -33,15 +33,13 @@ @overload def create_observation( observation_shape: Sequence[int], dtype: DType = np.float32 -) -> NDArray: - ... +) -> NDArray: ... @overload def create_observation( observation_shape: Sequence[Sequence[int]], dtype: DType = np.float32 -) -> Sequence[NDArray]: - ... +) -> Sequence[NDArray]: ... def create_observation( @@ -60,15 +58,13 @@ def create_observation( @overload def create_torch_observation( observation_shape: Sequence[int], dtype: DType = np.float32 -) -> torch.Tensor: - ... +) -> torch.Tensor: ... @overload def create_torch_observation( observation_shape: Sequence[Sequence[int]], dtype: DType = np.float32 -) -> Sequence[torch.Tensor]: - ... +) -> Sequence[torch.Tensor]: ... def create_torch_observation( @@ -82,8 +78,7 @@ def create_torch_observation( @overload def create_observations( observation_shape: Sequence[int], length: int, dtype: DType = np.float32 -) -> NDArray: - ... +) -> NDArray: ... @overload @@ -91,8 +86,7 @@ def create_observations( observation_shape: Sequence[Sequence[int]], length: int, dtype: DType = np.float32, -) -> Sequence[NDArray]: - ... +) -> Sequence[NDArray]: ... def create_observations( @@ -114,8 +108,7 @@ def create_observations( @overload def create_torch_observations( observation_shape: Sequence[int], length: int, dtype: DType = np.float32 -) -> torch.Tensor: - ... +) -> torch.Tensor: ... 
@overload @@ -123,8 +116,7 @@ def create_torch_observations( observation_shape: Sequence[Sequence[int]], length: int, dtype: DType = np.float32, -) -> Sequence[torch.Tensor]: - ... +) -> Sequence[torch.Tensor]: ... def create_torch_observations( @@ -141,8 +133,7 @@ def create_torch_batched_observations( batch_size: int, length: int, dtype: DType = np.float32, -) -> torch.Tensor: - ... +) -> torch.Tensor: ... @overload @@ -151,8 +142,7 @@ def create_torch_batched_observations( batch_size: int, length: int, dtype: DType = np.float32, -) -> Sequence[torch.Tensor]: - ... +) -> Sequence[torch.Tensor]: ... def create_torch_batched_observations(
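
The hunks above reduce to three mechanical style changes applied across the codebase: a blank line is inserted between a class docstring and its first field, stub bodies (`...`) are collapsed onto the same line as the signature in `@overload` and `Protocol` definitions, and multi-line conditional expressions are wrapped in their own parentheses instead of hanging off the assignment or argument. The sketch below is illustrative only, not part of the patch: the module, class, and function names are hypothetical and do not exist in d3rlpy, and the exact output can vary with the black version in use.

```python
from typing import Optional, Protocol


class ExampleConfig:
    """Docstrings are now separated from the first field by a blank line."""

    learning_rate: float = 3e-4
    batch_size: int = 256


class ExamplePredictor(Protocol):
    # Stub bodies ("...") now sit on the same line as the signature.
    def predict(self, x: float) -> float: ...


def normalize_name(name: Optional[str]) -> Optional[str]:
    # Conditional expressions that are too long for one line are now wrapped
    # in their own parentheses rather than hanging off the return/assignment.
    return (
        name.strip().lower().replace("-", "_")
        if name is not None and name != ""
        else None
    )
```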