Skip to content

Commit

Permalink
Fix lint errors
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Nov 2, 2024
1 parent f6de602 commit 5209707
Show file tree
Hide file tree
Showing 37 changed files with 209 additions and 182 deletions.
8 changes: 5 additions & 3 deletions d3rlpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=protected-access
import random

import gymnasium
Expand Down Expand Up @@ -68,9 +69,10 @@ def seed(n: int) -> None:
# run healthcheck
run_healthcheck()

# enable autograd compilation
torch._dynamo.config.compiled_autograd = True
torch.set_float32_matmul_precision("high")
if torch.cuda.is_available():
# enable autograd compilation
torch._dynamo.config.compiled_autograd = True
torch.set_float32_matmul_precision("high")

# register Shimmy if available
try:
Expand Down
6 changes: 3 additions & 3 deletions d3rlpy/algos/qlearning/awac.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class AWACConfig(LearnableConfig):
n_action_samples (int): Number of sampled actions to calculate
:math:`A^\pi(s_t, a_t)`.
n_critics (int): Number of Q functions for ensemble.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 3e-4
Expand All @@ -86,7 +86,7 @@ class AWACConfig(LearnableConfig):
lam: float = 1.0
n_action_samples: int = 1
n_critics: int = 2
compile: bool = False
compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
Expand Down Expand Up @@ -160,7 +160,7 @@ def inner_create_impl(
tau=self._config.tau,
lam=self._config.lam,
n_action_samples=self._config.n_action_samples,
compile=self._config.compile and "cuda" in self._device,
compile_graph=self._config.compile_graph and "cuda" in self._device,
device=self._device,
)

Expand Down
12 changes: 6 additions & 6 deletions d3rlpy/algos/qlearning/bcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class BCQConfig(LearnableConfig):
rl_start_step (int): Steps to start to update policy function and Q
functions. If this is large, RL training would be more stabilized.
beta (float): KL regularization term for Conditional VAE.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 1e-3
Expand All @@ -160,7 +160,7 @@ class BCQConfig(LearnableConfig):
action_flexibility: float = 0.05
rl_start_step: int = 0
beta: float = 0.5
compile: bool = False
compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
Expand Down Expand Up @@ -266,7 +266,7 @@ def inner_create_impl(
action_flexibility=self._config.action_flexibility,
beta=self._config.beta,
rl_start_step=self._config.rl_start_step,
compile=self._config.compile and "cuda" in self._device,
compile_graph=self._config.compile_graph and "cuda" in self._device,
device=self._device,
)

Expand Down Expand Up @@ -334,7 +334,7 @@ class DiscreteBCQConfig(LearnableConfig):
target_update_interval (int): Interval to update the target network.
share_encoder (bool): Flag to share encoder between Q-function and
imitation models.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

learning_rate: float = 6.25e-5
Expand All @@ -348,7 +348,7 @@ class DiscreteBCQConfig(LearnableConfig):
beta: float = 0.5
target_update_interval: int = 8000
share_encoder: bool = True
compile: bool = False
compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
Expand Down Expand Up @@ -427,7 +427,7 @@ def inner_create_impl(
gamma=self._config.gamma,
action_flexibility=self._config.action_flexibility,
beta=self._config.beta,
compile=self._config.compile and "cuda" in self._device,
compile_graph=self._config.compile_graph and "cuda" in self._device,
device=self._device,
)

Expand Down
6 changes: 3 additions & 3 deletions d3rlpy/algos/qlearning/bear.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class BEARConfig(LearnableConfig):
policy training.
warmup_steps (int): Number of steps to warmup the policy
function.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 1e-4
Expand Down Expand Up @@ -146,7 +146,7 @@ class BEARConfig(LearnableConfig):
mmd_sigma: float = 20.0
vae_kl_weight: float = 0.5
warmup_steps: int = 40000
compile: bool = False
compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
Expand Down Expand Up @@ -268,7 +268,7 @@ def inner_create_impl(
mmd_sigma=self._config.mmd_sigma,
vae_kl_weight=self._config.vae_kl_weight,
warmup_steps=self._config.warmup_steps,
compile=self._config.compile and "cuda" in self._device,
compile_graph=self._config.compile_graph and "cuda" in self._device,
device=self._device,
)

Expand Down
4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/cal_ql.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class CalQLConfig(CQLConfig):
:math:`\log{\sum_a \exp{Q(s, a)}}`.
soft_q_backup (bool): Flag to use SAC-style backup.
max_q_backup (bool): Flag to sample max Q-values for target.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

def create(
Expand Down Expand Up @@ -172,7 +172,7 @@ def inner_create_impl(
n_action_samples=self._config.n_action_samples,
soft_q_backup=self._config.soft_q_backup,
max_q_backup=self._config.max_q_backup,
compile=self._config.compile,
compile_graph=self._config.compile_graph and "cuda" in self._device,
device=self._device,
)

Expand Down
12 changes: 6 additions & 6 deletions d3rlpy/algos/qlearning/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class CQLConfig(LearnableConfig):
:math:`\log{\sum_a \exp{Q(s, a)}}`.
soft_q_backup (bool): Flag to use SAC-style backup.
max_q_backup (bool): Flag to sample max Q-values for target.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 1e-4
Expand All @@ -125,7 +125,7 @@ class CQLConfig(LearnableConfig):
n_action_samples: int = 10
soft_q_backup: bool = False
max_q_backup: bool = False
compile: bool = False
compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
Expand Down Expand Up @@ -227,7 +227,7 @@ def inner_create_impl(
n_action_samples=self._config.n_action_samples,
soft_q_backup=self._config.soft_q_backup,
max_q_backup=self._config.max_q_backup,
compile=self._config.compile and "cuda" in self._device,
compile_graph=self._config.compile_graph and "cuda" in self._device,
device=self._device,
)

Expand Down Expand Up @@ -275,7 +275,7 @@ class DiscreteCQLConfig(LearnableConfig):
target_update_interval (int): Interval to synchronize the target
network.
alpha (float): :math:`\alpha` value above.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

learning_rate: float = 6.25e-5
Expand All @@ -287,7 +287,7 @@ class DiscreteCQLConfig(LearnableConfig):
n_critics: int = 1
target_update_interval: int = 8000
alpha: float = 1.0
compile: bool = False
compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
Expand Down Expand Up @@ -341,7 +341,7 @@ def inner_create_impl(
target_update_interval=self._config.target_update_interval,
gamma=self._config.gamma,
alpha=self._config.alpha,
compile=self._config.compile and "cuda" in self._device,
compile_graph=self._config.compile_graph and "cuda" in self._device,
device=self._device,
)

Expand Down
3 changes: 3 additions & 0 deletions d3rlpy/algos/qlearning/crr.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class CRRConfig(LearnableConfig):
``soft`` target update.
update_actor_interval (int): Interval to update policy function used
with ``hard`` target update.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 3e-4
Expand All @@ -120,6 +121,7 @@ class CRRConfig(LearnableConfig):
tau: float = 5e-3
target_update_interval: int = 100
update_actor_interval: int = 1
compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
Expand Down Expand Up @@ -199,6 +201,7 @@ def inner_create_impl(
tau=self._config.tau,
target_update_type=self._config.target_update_type,
target_update_interval=self._config.target_update_interval,
compile_graph=self._config.compile_graph and "cuda" in self._device,
device=self._device,
)

Expand Down
6 changes: 3 additions & 3 deletions d3rlpy/algos/qlearning/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class DDPGConfig(LearnableConfig):
gamma (float): Discount factor.
tau (float): Target network synchronization coefficient.
n_critics (int): Number of Q functions for ensemble.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

batch_size: int = 256
Expand All @@ -82,7 +82,7 @@ class DDPGConfig(LearnableConfig):
q_func_factory: QFunctionFactory = make_q_func_field()
tau: float = 0.005
n_critics: int = 1
compile: bool = False
compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
Expand Down Expand Up @@ -155,7 +155,7 @@ def inner_create_impl(
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
tau=self._config.tau,
compile=self._config.compile,
compile_graph=self._config.compile_graph,
device=self._device,
)

Expand Down
12 changes: 6 additions & 6 deletions d3rlpy/algos/qlearning/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class DQNConfig(LearnableConfig):
gamma (float): Discount factor.
n_critics (int): Number of Q functions for ensemble.
target_update_interval (int): Interval to update the target network.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

batch_size: int = 32
Expand All @@ -55,7 +55,7 @@ class DQNConfig(LearnableConfig):
gamma: float = 0.99
n_critics: int = 1
target_update_interval: int = 8000
compile: bool = False
compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
Expand Down Expand Up @@ -108,7 +108,7 @@ def inner_create_impl(
target_update_interval=self._config.target_update_interval,
modules=modules,
gamma=self._config.gamma,
compile=self._config.compile,
compile_graph=self._config.compile_graph and "cuda" in self._device,
device=self._device,
)

Expand Down Expand Up @@ -154,7 +154,7 @@ class DoubleDQNConfig(DQNConfig):
n_critics (int): Number of Q functions.
target_update_interval (int): Interval to synchronize the target
network.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

batch_size: int = 32
Expand All @@ -165,7 +165,7 @@ class DoubleDQNConfig(DQNConfig):
gamma: float = 0.99
n_critics: int = 1
target_update_interval: int = 8000
compile: bool = False
compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
Expand Down Expand Up @@ -218,7 +218,7 @@ def inner_create_impl(
targ_q_func_forwarder=targ_forwarder,
target_update_interval=self._config.target_update_interval,
gamma=self._config.gamma,
compile=self._config.compile,
compile_graph=self._config.compile_graph and "cuda" in self._device,
device=self._device,
)

Expand Down
6 changes: 3 additions & 3 deletions d3rlpy/algos/qlearning/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class IQLConfig(LearnableConfig):
weight_temp (float): Inverse temperature value represented as
:math:`\beta`.
max_weight (float): Maximum advantage weight value to clip.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 3e-4
Expand All @@ -97,7 +97,7 @@ class IQLConfig(LearnableConfig):
expectile: float = 0.7
weight_temp: float = 3.0
max_weight: float = 100.0
compile: bool = False
compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
Expand Down Expand Up @@ -177,7 +177,7 @@ def inner_create_impl(
expectile=self._config.expectile,
weight_temp=self._config.weight_temp,
max_weight=self._config.max_weight,
compile=self._config.compile and "cuda" in self._device,
compile_graph=self._config.compile_graph and "cuda" in self._device,
device=self._device,
)

Expand Down
3 changes: 3 additions & 0 deletions d3rlpy/algos/qlearning/nfq.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class NFQConfig(LearnableConfig):
batch_size (int): Mini-batch size.
gamma (float): Discount factor.
n_critics (int): Number of Q functions for ensemble.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

learning_rate: float = 6.25e-5
Expand All @@ -56,6 +57,7 @@ class NFQConfig(LearnableConfig):
batch_size: int = 32
gamma: float = 0.99
n_critics: int = 1
compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
Expand Down Expand Up @@ -108,6 +110,7 @@ def inner_create_impl(
targ_q_func_forwarder=targ_q_func_forwarder,
target_update_interval=1,
gamma=self._config.gamma,
compile_graph=self._config.compile_graph and "cuda" in self._device,
device=self._device,
)

Expand Down
10 changes: 5 additions & 5 deletions d3rlpy/algos/qlearning/plas.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class PLASConfig(LearnableConfig):
lam (float): Weight factor for critic ensemble.
warmup_steps (int): Number of steps to warmup the VAE.
beta (float): KL regularization term for Conditional VAE.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 1e-4
Expand All @@ -97,7 +97,7 @@ class PLASConfig(LearnableConfig):
lam: float = 0.75
warmup_steps: int = 500000
beta: float = 0.5
compile: bool = False
compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
Expand Down Expand Up @@ -199,7 +199,7 @@ def inner_create_impl(
lam=self._config.lam,
beta=self._config.beta,
warmup_steps=self._config.warmup_steps,
compile=self._config.compile and "cuda" in self._device,
compile_graph=self._config.compile_graph and "cuda" in self._device,
device=self._device,
)

Expand Down Expand Up @@ -250,7 +250,7 @@ class PLASWithPerturbationConfig(PLASConfig):
action_flexibility (float): Output scale of perturbation layer.
warmup_steps (int): Number of steps to warmup the VAE.
beta (float): KL regularization term for Conditional VAE.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

action_flexibility: float = 0.05
Expand Down Expand Up @@ -377,7 +377,7 @@ def inner_create_impl(
lam=self._config.lam,
beta=self._config.beta,
warmup_steps=self._config.warmup_steps,
compile=self._config.compile and "cuda" in self._device,
compile_graph=self._config.compile_graph and "cuda" in self._device,
device=self._device,
)

Expand Down
Loading

0 comments on commit 5209707

Please sign in to comment.