Support CudaGraph and torch.compile (#428)
* Add CudaGraphWrapper

* Fix lint errors

* Fix TD3

* Fix DiscreteSAC

* Update torch dependency

* Fix lint error

* Add compiled flag to OptimizerWrapper

* Workaround DiscreteSAC test

* Add compiled property

* Add tests

* Update python version in readthedocs

* Remove unnecessary change

* Rename compile_graph to compiled

* Support BC

* Add compile option to reproduction scripts
takuseno authored Nov 3, 2024
1 parent 3b01da3 commit f3c5540
Showing 72 changed files with 824 additions and 276 deletions.
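
For reference, the new option is surfaced on each algorithm config as the `compile_graph` flag documented in the diffs below. A minimal usage sketch, assuming the existing `Config.create()` / `fit()` API and the bundled `get_pendulum()` dataset helper (neither is part of this commit):

import d3rlpy

# toy dataset/environment pair bundled with d3rlpy
dataset, env = d3rlpy.datasets.get_pendulum()

# compile_graph enables torch.compile + CUDAGraph-based training on a CUDA device
cql = d3rlpy.algos.CQLConfig(compile_graph=True).create(device="cuda:0")

cql.fit(dataset, n_steps=10000, n_steps_per_epoch=1000)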
2 changes: 1 addition & 1 deletion .readthedocs.yaml
@@ -2,7 +2,7 @@ version: 2
 build:
   os: ubuntu-22.04
   tools:
-    python: "3.8"
+    python: "3.10"
 sphinx:
   builder: html
   configuration: docs/conf.py
2 changes: 1 addition & 1 deletion README.md
@@ -54,7 +54,7 @@ d3rlpy supports Linux, macOS and Windows.
 
 ### Dependencies
 Installing d3rlpy package will install or upgrade the following packages to satisfy requirements:
-- torch>=2.0.0
+- torch>=2.5.0
 - tqdm>=4.66.3
 - gym>=0.26.0
 - gymnasium>=1.0.0
5 changes: 5 additions & 0 deletions d3rlpy/__init__.py
@@ -1,3 +1,4 @@
+# pylint: disable=protected-access
 import random
 
 import gymnasium
@@ -68,6 +69,10 @@ def seed(n: int) -> None:
 # run healthcheck
 run_healthcheck()
 
+if torch.cuda.is_available():
+    # enable autograd compilation
+    torch._dynamo.config.compiled_autograd = True
+    torch.set_float32_matmul_precision("high")
 
 # register Shimmy if available
 try:
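
The two lines added above take effect globally whenever a CUDA device is present: `torch._dynamo.config.compiled_autograd = True` lets Dynamo trace backward passes in addition to forward passes, and `torch.set_float32_matmul_precision("high")` permits TF32 matmuls on GPUs that support them. A standalone sketch of the same settings in plain PyTorch (nothing d3rlpy-specific assumed):

import torch
import torch.nn as nn

# same two settings the commit adds to d3rlpy/__init__.py, applied only on CUDA machines
if torch.cuda.is_available():
    torch._dynamo.config.compiled_autograd = True  # let Dynamo capture backward passes
    torch.set_float32_matmul_precision("high")     # allow TF32 matmuls where supported

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 1)).to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# a compiled forward/loss function; compiled_autograd affects its backward pass
loss_fn = torch.compile(lambda x, y: nn.functional.mse_loss(model(x), y))

x, y = torch.randn(32, 8, device=device), torch.randn(32, 1, device=device)
loss = loss_fn(x, y)
loss.backward()
optim.step()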
10 changes: 8 additions & 2 deletions d3rlpy/algos/qlearning/awac.py
@@ -70,6 +70,7 @@ class AWACConfig(LearnableConfig):
         n_action_samples (int): Number of sampled actions to calculate
             :math:`A^\pi(s_t, a_t)`.
         n_critics (int): Number of Q functions for ensemble.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     actor_learning_rate: float = 3e-4
@@ -130,10 +131,14 @@ def inner_create_impl(
         )
 
         actor_optim = self._config.actor_optim_factory.create(
-            policy.named_modules(), lr=self._config.actor_learning_rate
+            policy.named_modules(),
+            lr=self._config.actor_learning_rate,
+            compiled=self.compiled,
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.named_modules(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(),
+            lr=self._config.critic_learning_rate,
+            compiled=self.compiled,
         )
 
         dummy_log_temp = Parameter(torch.zeros(1, 1))
@@ -158,6 +163,7 @@ def inner_create_impl(
             tau=self._config.tau,
             lam=self._config.lam,
             n_action_samples=self._config.n_action_samples,
+            compiled=self.compiled,
             device=self._device,
         )
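
Every optimizer-factory `create()` call in this commit now forwards `compiled=self.compiled` (see the commit bullet "Add compiled flag to OptimizerWrapper"). The `OptimizerWrapper` changes themselves are not part of this excerpt; the following is only a hypothetical sketch of what such a flag could control, with all names assumed rather than taken from d3rlpy:

from typing import Iterable, Tuple

import torch
from torch import nn


class OptimizerWrapperSketch:
    """Hypothetical: when compiled=True, wrap the parameter update with
    torch.compile so it can fuse with a compiled training step."""

    def __init__(
        self,
        named_modules: Iterable[Tuple[str, nn.Module]],
        lr: float,
        compiled: bool,
    ):
        # collect parameters from the (name, module) pairs passed by the algorithms
        params = [p for _, m in named_modules for p in m.parameters(recurse=False)]
        self._optim = torch.optim.Adam(params, lr=lr)
        # torch.compile accepts any callable, including a bound optimizer step
        self._step = torch.compile(self._optim.step) if compiled else self._optim.step

    def zero_grad(self) -> None:
        self._optim.zero_grad()

    def step(self) -> None:
        self._step()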
12 changes: 10 additions & 2 deletions d3rlpy/algos/qlearning/bc.py
@@ -49,6 +49,7 @@ class BCConfig(LearnableConfig):
         observation_scaler (d3rlpy.preprocessing.ObservationScaler):
             Observation preprocessor.
         action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     batch_size: int = 100
@@ -93,7 +94,9 @@ def inner_create_impl(
             raise ValueError(f"invalid policy_type: {self._config.policy_type}")
 
         optim = self._config.optim_factory.create(
-            imitator.named_modules(), lr=self._config.learning_rate
+            imitator.named_modules(),
+            lr=self._config.learning_rate,
+            compiled=self.compiled,
         )
 
         modules = BCModules(optim=optim, imitator=imitator)
@@ -103,6 +106,7 @@ def inner_create_impl(
             action_size=action_size,
             modules=modules,
             policy_type=self._config.policy_type,
+            compiled=self.compiled,
             device=self._device,
         )
@@ -137,6 +141,7 @@ class DiscreteBCConfig(LearnableConfig):
         beta (float): Reguralization factor.
         observation_scaler (d3rlpy.preprocessing.ObservationScaler):
             Observation preprocessor.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     batch_size: int = 100
@@ -168,7 +173,9 @@ def inner_create_impl(
         )
 
         optim = self._config.optim_factory.create(
-            imitator.named_modules(), lr=self._config.learning_rate
+            imitator.named_modules(),
+            lr=self._config.learning_rate,
+            compiled=self.compiled,
         )
 
         modules = DiscreteBCModules(optim=optim, imitator=imitator)
@@ -178,6 +185,7 @@ def inner_create_impl(
             action_size=action_size,
             modules=modules,
             beta=self._config.beta,
+            compiled=self.compiled,
             device=self._device,
         )
17 changes: 14 additions & 3 deletions d3rlpy/algos/qlearning/bcq.py
@@ -137,6 +137,7 @@ class BCQConfig(LearnableConfig):
         rl_start_step (int): Steps to start to update policy function and Q
             functions. If this is large, RL training would be more stabilized.
         beta (float): KL reguralization term for Conditional VAE.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     actor_learning_rate: float = 1e-3
@@ -228,15 +229,20 @@ def inner_create_impl(
         )
 
         actor_optim = self._config.actor_optim_factory.create(
-            policy.named_modules(), lr=self._config.actor_learning_rate
+            policy.named_modules(),
+            lr=self._config.actor_learning_rate,
+            compiled=self.compiled,
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.named_modules(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(),
+            lr=self._config.critic_learning_rate,
+            compiled=self.compiled,
         )
         vae_optim = self._config.imitator_optim_factory.create(
             list(vae_encoder.named_modules())
             + list(vae_decoder.named_modules()),
             lr=self._config.imitator_learning_rate,
+            compiled=self.compiled,
         )
 
         modules = BCQModules(
@@ -264,6 +270,7 @@ def inner_create_impl(
             action_flexibility=self._config.action_flexibility,
             beta=self._config.beta,
             rl_start_step=self._config.rl_start_step,
+            compiled=self.compiled,
             device=self._device,
         )
@@ -331,6 +338,7 @@ class DiscreteBCQConfig(LearnableConfig):
         target_update_interval (int): Interval to update the target network.
         share_encoder (bool): Flag to share encoder between Q-function and
             imitation models.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     learning_rate: float = 6.25e-5
@@ -402,7 +410,9 @@ def inner_create_impl(
         q_func_params = list(q_funcs.named_modules())
         imitator_params = list(imitator.named_modules())
         optim = self._config.optim_factory.create(
-            q_func_params + imitator_params, lr=self._config.learning_rate
+            q_func_params + imitator_params,
+            lr=self._config.learning_rate,
+            compiled=self.compiled,
         )
 
         modules = DiscreteBCQModules(
@@ -422,6 +432,7 @@ def inner_create_impl(
             gamma=self._config.gamma,
             action_flexibility=self._config.action_flexibility,
             beta=self._config.beta,
+            compiled=self.compiled,
             device=self._device,
         )
19 changes: 15 additions & 4 deletions d3rlpy/algos/qlearning/bear.py
@@ -114,6 +114,7 @@ class BEARConfig(LearnableConfig):
             policy training.
         warmup_steps (int): Number of steps to warmup the policy
             function.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     actor_learning_rate: float = 1e-4
@@ -217,21 +218,30 @@ def inner_create_impl(
         )
 
         actor_optim = self._config.actor_optim_factory.create(
-            policy.named_modules(), lr=self._config.actor_learning_rate
+            policy.named_modules(),
+            lr=self._config.actor_learning_rate,
+            compiled=self.compiled,
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.named_modules(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(),
+            lr=self._config.critic_learning_rate,
+            compiled=self.compiled,
         )
         vae_optim = self._config.imitator_optim_factory.create(
             list(vae_encoder.named_modules())
             + list(vae_decoder.named_modules()),
             lr=self._config.imitator_learning_rate,
+            compiled=self.compiled,
         )
         temp_optim = self._config.temp_optim_factory.create(
-            log_temp.named_modules(), lr=self._config.temp_learning_rate
+            log_temp.named_modules(),
+            lr=self._config.temp_learning_rate,
+            compiled=self.compiled,
         )
         alpha_optim = self._config.alpha_optim_factory.create(
-            log_alpha.named_modules(), lr=self._config.actor_learning_rate
+            log_alpha.named_modules(),
+            lr=self._config.actor_learning_rate,
+            compiled=self.compiled,
         )
 
         modules = BEARModules(
@@ -266,6 +276,7 @@ def inner_create_impl(
             mmd_sigma=self._config.mmd_sigma,
             vae_kl_weight=self._config.vae_kl_weight,
             warmup_steps=self._config.warmup_steps,
+            compiled=self.compiled,
             device=self._device,
         )
19 changes: 14 additions & 5 deletions d3rlpy/algos/qlearning/cal_ql.py
@@ -69,6 +69,7 @@ class CalQLConfig(CQLConfig):
             :math:`\log{\sum_a \exp{Q(s, a)}}`.
         soft_q_backup (bool): Flag to use SAC-style backup.
         max_q_backup (bool): Flag to sample max Q-values for target.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     def create(
@@ -88,7 +89,6 @@ def inner_create_impl(
         assert not (
             self._config.soft_q_backup and self._config.max_q_backup
         ), "soft_q_backup and max_q_backup are mutually exclusive."
-
         policy = create_normal_policy(
             observation_shape,
             action_size,
@@ -128,20 +128,28 @@
         )
 
         actor_optim = self._config.actor_optim_factory.create(
-            policy.named_modules(), lr=self._config.actor_learning_rate
+            policy.named_modules(),
+            lr=self._config.actor_learning_rate,
+            compiled=self.compiled,
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.named_modules(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(),
+            lr=self._config.critic_learning_rate,
+            compiled=self.compiled,
         )
         if self._config.temp_learning_rate > 0:
             temp_optim = self._config.temp_optim_factory.create(
-                log_temp.named_modules(), lr=self._config.temp_learning_rate
+                log_temp.named_modules(),
+                lr=self._config.temp_learning_rate,
+                compiled=self.compiled,
             )
         else:
             temp_optim = None
         if self._config.alpha_learning_rate > 0:
             alpha_optim = self._config.alpha_optim_factory.create(
-                log_alpha.named_modules(), lr=self._config.alpha_learning_rate
+                log_alpha.named_modules(),
+                lr=self._config.alpha_learning_rate,
+                compiled=self.compiled,
             )
         else:
             alpha_optim = None
@@ -171,6 +179,7 @@ def inner_create_impl(
             n_action_samples=self._config.n_action_samples,
             soft_q_backup=self._config.soft_q_backup,
             max_q_backup=self._config.max_q_backup,
+            compiled=self.compiled,
             device=self._device,
         )
25 changes: 19 additions & 6 deletions d3rlpy/algos/qlearning/cql.py
@@ -100,6 +100,7 @@ class CQLConfig(LearnableConfig):
             :math:`\log{\sum_a \exp{Q(s, a)}}`.
         soft_q_backup (bool): Flag to use SAC-style backup.
         max_q_backup (bool): Flag to sample max Q-values for target.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     actor_learning_rate: float = 1e-4
@@ -142,7 +143,6 @@ def inner_create_impl(
         assert not (
             self._config.soft_q_backup and self._config.max_q_backup
         ), "soft_q_backup and max_q_backup are mutually exclusive."
-
         policy = create_normal_policy(
             observation_shape,
             action_size,
@@ -182,20 +182,28 @@
         )
 
         actor_optim = self._config.actor_optim_factory.create(
-            policy.named_modules(), lr=self._config.actor_learning_rate
+            policy.named_modules(),
+            lr=self._config.actor_learning_rate,
+            compiled=self.compiled,
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.named_modules(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(),
+            lr=self._config.critic_learning_rate,
+            compiled=self.compiled,
        )
         if self._config.temp_learning_rate > 0:
             temp_optim = self._config.temp_optim_factory.create(
-                log_temp.named_modules(), lr=self._config.temp_learning_rate
+                log_temp.named_modules(),
+                lr=self._config.temp_learning_rate,
+                compiled=self.compiled,
             )
         else:
             temp_optim = None
         if self._config.alpha_learning_rate > 0:
             alpha_optim = self._config.alpha_optim_factory.create(
-                log_alpha.named_modules(), lr=self._config.alpha_learning_rate
+                log_alpha.named_modules(),
+                lr=self._config.alpha_learning_rate,
+                compiled=self.compiled,
             )
         else:
             alpha_optim = None
@@ -225,6 +233,7 @@ def inner_create_impl(
             n_action_samples=self._config.n_action_samples,
             soft_q_backup=self._config.soft_q_backup,
             max_q_backup=self._config.max_q_backup,
+            compiled=self.compiled,
             device=self._device,
         )
@@ -272,6 +281,7 @@ class DiscreteCQLConfig(LearnableConfig):
         target_update_interval (int): Interval to synchronize the target
             network.
         alpha (float): math:`\alpha` value above.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """
 
     learning_rate: float = 6.25e-5
@@ -318,7 +328,9 @@ def inner_create_impl(
         )
 
         optim = self._config.optim_factory.create(
-            q_funcs.named_modules(), lr=self._config.learning_rate
+            q_funcs.named_modules(),
+            lr=self._config.learning_rate,
+            compiled=self.compiled,
         )
 
         modules = DQNModules(
@@ -336,6 +348,7 @@ def inner_create_impl(
             target_update_interval=self._config.target_update_interval,
             gamma=self._config.gamma,
             alpha=self._config.alpha,
+            compiled=self.compiled,
             device=self._device,
         )
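
Each impl constructor above also receives `compiled=self.compiled`. The `CudaGraphWrapper` named in the commit message is not shown in this excerpt; below is a hypothetical sketch of the general idea, wrapping a training-step callable with `torch.compile(mode="reduce-overhead")`, the mode that replays CUDA graphs, and falling back to eager execution otherwise. Names and structure are assumptions, not d3rlpy's actual implementation:

from typing import Any, Callable

import torch


class CudaGraphWrapperSketch:
    """Hypothetical: compile a training-step callable with CUDA graph capture
    when enabled, otherwise run it eagerly."""

    def __init__(self, func: Callable[..., Any], enabled: bool) -> None:
        if enabled and torch.cuda.is_available():
            # "reduce-overhead" uses CUDA graphs under the hood
            self._func = torch.compile(func, mode="reduce-overhead")
        else:
            self._func = func  # eager fallback (e.g. CPU-only runs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self._func(*args, **kwargs)


# e.g. an impl could wrap its inner update when compiled=True:
# update_step = CudaGraphWrapperSketch(impl.inner_update, enabled=compiled)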
(Diffs for the remaining changed files are not shown.)