Commit

Apply upgraded black format
takuseno committed Feb 18, 2024
1 parent cc15df8 commit 8cb7c4d
Showing 42 changed files with 165 additions and 131 deletions.
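Most of the 42 files change in the same way: the upgraded black release (presumably 24.x, given the February 2024 commit date) inserts a blank line between a class docstring and the first statement of the class body, so nearly every dataclass-based config below gains one blank line after its docstring. A minimal, hypothetical sketch (not taken from d3rlpy) of the resulting layout:

import dataclasses


@dataclasses.dataclass()
class ExampleConfig:
    """Hypothetical config illustrating the new docstring spacing.

    Args:
        learning_rate (float): Learning rate.
    """

    # Upgraded black enforces exactly one blank line between the class
    # docstring and the first field; the old output had none, which is why
    # almost every config class in this diff gains a single added line.
    learning_rate: float = 3e-4
    batch_size: int = 256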
1 change: 1 addition & 0 deletions d3rlpy/algos/qlearning/awac.py
@@ -71,6 +71,7 @@ class AWACConfig(LearnableConfig):
:math:`A^\pi(s_t, a_t)`.
n_critics (int): Number of Q functions for ensemble.
"""

actor_learning_rate: float = 3e-4
critic_learning_rate: float = 3e-4
actor_optim_factory: OptimizerFactory = make_optimizer_field()
2 changes: 2 additions & 0 deletions d3rlpy/algos/qlearning/bc.py
@@ -50,6 +50,7 @@ class BCConfig(LearnableConfig):
Observation preprocessor.
action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
"""

batch_size: int = 100
learning_rate: float = 1e-3
policy_type: str = "deterministic"
@@ -133,6 +134,7 @@ class DiscreteBCConfig(LearnableConfig):
observation_scaler (d3rlpy.preprocessing.ObservationScaler):
Observation preprocessor.
"""

batch_size: int = 100
learning_rate: float = 1e-3
optim_factory: OptimizerFactory = make_optimizer_field()
2 changes: 2 additions & 0 deletions d3rlpy/algos/qlearning/bcq.py
@@ -138,6 +138,7 @@ class BCQConfig(LearnableConfig):
functions. If this is large, RL training would be more stabilized.
beta (float): KL reguralization term for Conditional VAE.
"""

actor_learning_rate: float = 1e-3
critic_learning_rate: float = 1e-3
imitator_learning_rate: float = 1e-3
@@ -323,6 +324,7 @@ class DiscreteBCQConfig(LearnableConfig):
share_encoder (bool): Flag to share encoder between Q-function and
imitation models.
"""

learning_rate: float = 6.25e-5
optim_factory: OptimizerFactory = make_optimizer_field()
encoder_factory: EncoderFactory = make_encoder_field()
1 change: 1 addition & 0 deletions d3rlpy/algos/qlearning/bear.py
@@ -115,6 +115,7 @@ class BEARConfig(LearnableConfig):
warmup_steps (int): Number of steps to warmup the policy
function.
"""

actor_learning_rate: float = 1e-4
critic_learning_rate: float = 3e-4
imitator_learning_rate: float = 3e-4
2 changes: 2 additions & 0 deletions d3rlpy/algos/qlearning/cql.py
@@ -100,6 +100,7 @@ class CQLConfig(LearnableConfig):
:math:`\log{\sum_a \exp{Q(s, a)}}`.
soft_q_backup (bool): Flag to use SAC-style backup.
"""

actor_learning_rate: float = 1e-4
critic_learning_rate: float = 3e-4
temp_learning_rate: float = 1e-4
@@ -256,6 +257,7 @@ class DiscreteCQLConfig(LearnableConfig):
network.
alpha (float): math:`\alpha` value above.
"""

learning_rate: float = 6.25e-5
optim_factory: OptimizerFactory = make_optimizer_field()
encoder_factory: EncoderFactory = make_encoder_field()
1 change: 1 addition & 0 deletions d3rlpy/algos/qlearning/crr.py
@@ -100,6 +100,7 @@ class CRRConfig(LearnableConfig):
update_actor_interval (int): Interval to update policy function used
with ``hard`` target update.
"""

actor_learning_rate: float = 3e-4
critic_learning_rate: float = 3e-4
actor_optim_factory: OptimizerFactory = make_optimizer_field()
1 change: 1 addition & 0 deletions d3rlpy/algos/qlearning/ddpg.py
@@ -70,6 +70,7 @@ class DDPGConfig(LearnableConfig):
tau (float): Target network synchronization coefficiency.
n_critics (int): Number of Q functions for ensemble.
"""

batch_size: int = 256
actor_learning_rate: float = 3e-4
critic_learning_rate: float = 3e-4
2 changes: 2 additions & 0 deletions d3rlpy/algos/qlearning/dqn.py
@@ -45,6 +45,7 @@ class DQNConfig(LearnableConfig):
n_critics (int): Number of Q functions for ensemble.
target_update_interval (int): Interval to update the target network.
"""

batch_size: int = 32
learning_rate: float = 6.25e-5
optim_factory: OptimizerFactory = make_optimizer_field()
@@ -147,6 +148,7 @@ class DoubleDQNConfig(DQNConfig):
target_update_interval (int): Interval to synchronize the target
network.
"""

batch_size: int = 32
learning_rate: float = 6.25e-5
optim_factory: OptimizerFactory = make_optimizer_field()
1 change: 1 addition & 0 deletions d3rlpy/algos/qlearning/iql.py
@@ -81,6 +81,7 @@ class IQLConfig(LearnableConfig):
:math:`\beta`.
max_weight (float): Maximum advantage weight value to clip.
"""

actor_learning_rate: float = 3e-4
critic_learning_rate: float = 3e-4
actor_optim_factory: OptimizerFactory = make_optimizer_field()
1 change: 1 addition & 0 deletions d3rlpy/algos/qlearning/nfq.py
@@ -48,6 +48,7 @@ class NFQConfig(LearnableConfig):
gamma (float): Discount factor.
n_critics (int): Number of Q functions for ensemble.
"""

learning_rate: float = 6.25e-5
optim_factory: OptimizerFactory = make_optimizer_field()
encoder_factory: EncoderFactory = make_encoder_field()
2 changes: 2 additions & 0 deletions d3rlpy/algos/qlearning/plas.py
@@ -78,6 +78,7 @@ class PLASConfig(LearnableConfig):
warmup_steps (int): Number of steps to warmup the VAE.
beta (float): KL reguralization term for Conditional VAE.
"""

actor_learning_rate: float = 1e-4
critic_learning_rate: float = 1e-3
imitator_learning_rate: float = 1e-4
@@ -239,6 +240,7 @@ class PLASWithPerturbationConfig(PLASConfig):
warmup_steps (int): Number of steps to warmup the VAE.
beta (float): KL reguralization term for Conditional VAE.
"""

action_flexibility: float = 0.05

def create(self, device: DeviceArg = False) -> "PLASWithPerturbation":
1 change: 1 addition & 0 deletions d3rlpy/algos/qlearning/random_policy.py
@@ -31,6 +31,7 @@ class RandomPolicyConfig(LearnableConfig):
normal_std (float): Standard deviation of the normal distribution. This
is only used when ``distribution='normal'``.
"""

distribution: str = "uniform"
normal_std: float = 1.0

2 changes: 2 additions & 0 deletions d3rlpy/algos/qlearning/sac.py
@@ -95,6 +95,7 @@ class SACConfig(LearnableConfig):
n_critics (int): Number of Q functions for ensemble.
initial_temperature (float): Initial temperature value.
"""

actor_learning_rate: float = 3e-4
critic_learning_rate: float = 3e-4
temp_learning_rate: float = 3e-4
@@ -243,6 +244,7 @@ class DiscreteSACConfig(LearnableConfig):
n_critics (int): Number of Q functions for ensemble.
initial_temperature (float): Initial temperature value.
"""

actor_learning_rate: float = 3e-4
critic_learning_rate: float = 3e-4
temp_learning_rate: float = 3e-4
1 change: 1 addition & 0 deletions d3rlpy/algos/qlearning/td3.py
@@ -75,6 +75,7 @@ class TD3Config(LearnableConfig):
update_actor_interval (int): Interval to update policy function
described as `delayed policy update` in the paper.
"""

actor_learning_rate: float = 3e-4
critic_learning_rate: float = 3e-4
actor_optim_factory: OptimizerFactory = make_optimizer_field()
1 change: 1 addition & 0 deletions d3rlpy/algos/qlearning/td3_plus_bc.py
@@ -66,6 +66,7 @@ class TD3PlusBCConfig(LearnableConfig):
update_actor_interval (int): Interval to update policy function
described as `delayed policy update` in the paper.
"""

actor_learning_rate: float = 3e-4
critic_learning_rate: float = 3e-4
actor_optim_factory: OptimizerFactory = make_optimizer_field()
1 change: 1 addition & 0 deletions d3rlpy/algos/transformer/action_samplers.py
@@ -48,6 +48,7 @@ class SoftmaxTransformerActionSampler(TransformerActionSampler):
Args:
temperature (int): Softmax temperature.
"""

_temperature: float

def __init__(self, temperature: float = 1.0):
5 changes: 3 additions & 2 deletions d3rlpy/algos/transformer/base.py
@@ -46,7 +46,7 @@ def predict(self, inpt: TorchTransformerInput) -> torch.Tensor:

@abstractmethod
def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor:
-        ...
+        raise NotImplementedError

@train_api
def update(
@@ -58,7 +58,7 @@ def update(
def inner_update(
self, batch: TorchTrajectoryMiniBatch, grad_step: int
) -> Dict[str, float]:
-        pass
+        raise NotImplementedError


@dataclasses.dataclass()
@@ -101,6 +101,7 @@ class StatefulTransformerWrapper(Generic[TTransformerImpl, TTransformerConfig]):
target_return (float): Target return.
action_sampler (d3rlpy.algos.TransformerActionSampler): Action sampler.
"""

_algo: "TransformerAlgoBase[TTransformerImpl, TTransformerConfig]"
_target_return: float
_action_sampler: TransformerActionSampler
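The only non-whitespace edits in d3rlpy/algos/transformer/base.py swap the bare "..." and "pass" bodies of the abstract methods for "raise NotImplementedError". A plausible reason (not stated in the commit) is that the upgraded black collapses a function whose body is only "..." onto its signature line; raising keeps the body on its own line and makes a missing override fail loudly. A minimal sketch of the resulting pattern, with hypothetical names:

from abc import ABC, abstractmethod


class ExampleImplBase(ABC):
    # Hypothetical base class mirroring the pattern used in this file.

    @abstractmethod
    def inner_predict(self, x: float) -> float:
        # A body of just "..." would now be collapsed onto the def line by
        # black; raising keeps the stub explicit and errors out at runtime
        # if a subclass forgets to override the method.
        raise NotImplementedError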
6 changes: 3 additions & 3 deletions d3rlpy/base.py
@@ -97,9 +97,9 @@ def unwrap_models_by_ddp(self) -> None:
class LearnableConfig(DynamicConfig):
batch_size: int = 256
gamma: float = 0.99
-    observation_scaler: Optional[
-        ObservationScaler
-    ] = make_observation_scaler_field()
+    observation_scaler: Optional[ObservationScaler] = (
+        make_observation_scaler_field()
+    )
action_scaler: Optional[ActionScaler] = make_action_scaler_field()
reward_scaler: Optional[RewardScaler] = make_reward_scaler_field()

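The d3rlpy/base.py hunk above shows another rule from the upgraded style: when an annotated assignment no longer fits on one line, black now keeps the annotation intact and wraps the right-hand side in parentheses instead of splitting inside the subscripted annotation. A self-contained sketch with hypothetical names (long enough to force wrapping at black's default 88-character limit):

from typing import Optional


def make_default_observation_scaler_description() -> Optional[str]:
    # Hypothetical stand-in for d3rlpy's make_observation_scaler_field().
    return None


# Old black output split the annotation itself:
#   observation_scaler_description: Optional[
#       str
#   ] = make_default_observation_scaler_description()
# The upgraded style parenthesizes the right-hand side instead:
observation_scaler_description: Optional[str] = (
    make_default_observation_scaler_description()
)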
2 changes: 2 additions & 0 deletions d3rlpy/dataset/buffers.py
@@ -44,6 +44,7 @@ def __getitem__(self, index: int) -> Tuple[EpisodeBase, int]:

class InfiniteBuffer(BufferProtocol):
r"""Buffer with unlimited capacity."""

_transitions: List[Tuple[EpisodeBase, int]]
_episodes: List[EpisodeBase]

@@ -78,6 +79,7 @@ class FIFOBuffer(BufferProtocol):
Args:
limit (int): buffer capacity.
"""

_transitions: Deque[Tuple[EpisodeBase, int]]
_episodes: List[EpisodeBase]
_limit: int
5 changes: 5 additions & 0 deletions d3rlpy/dataset/components.py
@@ -40,6 +40,7 @@ class Signature:
dtype: List of numpy data types.
shape: List of array shapes.
"""

dtype: Sequence[DType]
shape: Sequence[Sequence[int]]

@@ -69,6 +70,7 @@ class Transition:
terminal: Flag of environment termination.
interval: Timesteps between ``observation`` and ``next_observation``.
"""

observation: Observation # (...)
action: NDArray # (...)
reward: Float32NDArray # (1,)
@@ -130,6 +132,7 @@ class PartialTrajectory:
masks: Sequence of masks that represent padding.
length: Sequence length.
"""

observations: ObservationSequence # (L, ...)
actions: NDArray # (L, ...)
rewards: Float32NDArray # (L, 1)
@@ -350,6 +353,7 @@ class Episode:
rewards: Sequence of rewards.
terminated: Flag of environment termination.
"""

observations: ObservationSequence
actions: NDArray
rewards: Float32NDArray
@@ -422,6 +426,7 @@ class DatasetInfo:
this represents dimension of action vectors. For discrete
action-space, this represents the number of discrete actions.
"""

observation_signature: Signature
action_signature: Signature
reward_signature: Signature
1 change: 1 addition & 0 deletions d3rlpy/dataset/episode_generator.py
@@ -32,6 +32,7 @@ class EpisodeGenerator(EpisodeGeneratorProtocol):
terminals: Sequence of environment terminal flags.
timeouts: Sequence of timeout flags.
"""

_observations: ObservationSequence
_actions: NDArray
_rewards: Float32NDArray
2 changes: 2 additions & 0 deletions d3rlpy/dataset/mini_batch.py
@@ -30,6 +30,7 @@ class TransitionMiniBatch:
intervals: Batched timesteps between observations and next
observations.
"""

observations: Union[Float32NDArray, Sequence[Float32NDArray]] # (B, ...)
actions: Float32NDArray # (B, ...)
rewards: Float32NDArray # (B, 1)
@@ -146,6 +147,7 @@ class TrajectoryMiniBatch:
masks: Batched masks that represent padding.
length: Length of trajectories.
"""

observations: Union[Float32NDArray, Sequence[Float32NDArray]] # (B, L, ...)
actions: Float32NDArray # (B, L, ...)
rewards: Float32NDArray # (B, L, 1)
2 changes: 2 additions & 0 deletions d3rlpy/dataset/replay_buffer.py
@@ -329,6 +329,7 @@ class ReplayBuffer(ReplayBufferBase):
for online training. ``cache_size`` needs to be greater than the
maximum possible episode length.
"""

_buffer: BufferProtocol
_transition_picker: TransitionPickerProtocol
_trajectory_slicer: TrajectorySlicerProtocol
@@ -587,6 +588,7 @@ class MixedReplayBuffer(ReplayBufferBase):
secondary_mix_ratio (float): Ratio to sample mini-batches from the
secondary replay buffer.
"""

_primary_replay_buffer: ReplayBufferBase
_secondary_replay_buffer: ReplayBufferBase
_secondary_mix_ratio: float
1 change: 1 addition & 0 deletions d3rlpy/dataset/trajectory_slicers.py
@@ -116,6 +116,7 @@ class FrameStackTrajectorySlicer(TrajectorySlicerProtocol):
Args:
n_frames: Number of frames to stack.
"""

_n_frames: int

def __init__(self, n_frames: int):
2 changes: 2 additions & 0 deletions d3rlpy/dataset/transition_pickers.py
@@ -103,6 +103,7 @@ class FrameStackTransitionPicker(TransitionPickerProtocol):
n_frames (int): Number of frames to stack.
gamma (float): Discount factor to compute return-to-go.
"""

_n_frames: int
_gamma: float

@@ -152,6 +153,7 @@ class MultiStepTransitionPicker(TransitionPickerProtocol):
``net_observation``.
gamma: Discount factor to compute a multi-step return.
"""

_n_steps: int
_gamma: float
