diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py
index b3033508..bed815d6 100644
--- a/d3rlpy/algos/qlearning/torch/bcq_impl.py
+++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py
@@ -91,7 +91,7 @@ def __init__(
         self._beta = beta
         self._rl_start_step = rl_start_step
         self._compute_imitator_grad = (
-            CudaGraphWrapper(self.compute_imitator_grad)  # type: ignore
+            CudaGraphWrapper(self.compute_imitator_grad)
             if compile_graph
             else self.compute_imitator_grad
         )
diff --git a/d3rlpy/algos/qlearning/torch/bear_impl.py b/d3rlpy/algos/qlearning/torch/bear_impl.py
index 3358d487..7aea6217 100644
--- a/d3rlpy/algos/qlearning/torch/bear_impl.py
+++ b/d3rlpy/algos/qlearning/torch/bear_impl.py
@@ -116,12 +116,12 @@ def __init__(
         self._vae_kl_weight = vae_kl_weight
         self._warmup_steps = warmup_steps
         self._compute_warmup_actor_grad = (
-            CudaGraphWrapper(self.compute_warmup_actor_grad)  # type: ignore
+            CudaGraphWrapper(self.compute_warmup_actor_grad)
             if compile_graph
             else self.compute_warmup_actor_grad
         )
         self._compute_imitator_grad = (
-            CudaGraphWrapper(self.compute_imitator_grad)  # type: ignore
+            CudaGraphWrapper(self.compute_imitator_grad)
             if compile_graph
             else self.compute_imitator_grad
         )
diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py
index 3d88dcfd..7a427147 100644
--- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py
+++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py
@@ -87,12 +87,12 @@ def __init__(
         self._q_func_forwarder = q_func_forwarder
         self._targ_q_func_forwarder = targ_q_func_forwarder
         self._compute_critic_grad = (
-            CudaGraphWrapper(self.compute_critic_grad)  # type: ignore
+            CudaGraphWrapper(self.compute_critic_grad)
             if compile_graph
             else self.compute_critic_grad
         )
         self._compute_actor_grad = (
-            CudaGraphWrapper(self.compute_actor_grad)  # type: ignore
+            CudaGraphWrapper(self.compute_actor_grad)
             if compile_graph
             else self.compute_actor_grad
         )
diff --git a/d3rlpy/algos/qlearning/torch/dqn_impl.py b/d3rlpy/algos/qlearning/torch/dqn_impl.py
index 0ecfaee9..e7835ff6 100644
--- a/d3rlpy/algos/qlearning/torch/dqn_impl.py
+++ b/d3rlpy/algos/qlearning/torch/dqn_impl.py
@@ -64,7 +64,7 @@ def __init__(
         self._targ_q_func_forwarder = targ_q_func_forwarder
         self._target_update_interval = target_update_interval
         self._compute_grad = (
-            CudaGraphWrapper(self.compute_grad)  # type: ignore
+            CudaGraphWrapper(self.compute_grad)
             if compile_graph
             else self.compute_grad
         )
diff --git a/d3rlpy/algos/qlearning/torch/plas_impl.py b/d3rlpy/algos/qlearning/torch/plas_impl.py
index 71620cdf..1d60010e 100644
--- a/d3rlpy/algos/qlearning/torch/plas_impl.py
+++ b/d3rlpy/algos/qlearning/torch/plas_impl.py
@@ -71,7 +71,7 @@ def __init__(
         self._beta = beta
         self._warmup_steps = warmup_steps
         self._compute_imitator_grad = (
-            CudaGraphWrapper(self.compute_imitator_grad)  # type: ignore
+            CudaGraphWrapper(self.compute_imitator_grad)
             if compile_graph
             else self.compute_imitator_grad
         )
diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py
index c4283c05..b9521e63 100644
--- a/d3rlpy/algos/qlearning/torch/sac_impl.py
+++ b/d3rlpy/algos/qlearning/torch/sac_impl.py
@@ -174,12 +174,12 @@ def __init__(
         self._targ_q_func_forwarder = targ_q_func_forwarder
         self._target_update_interval = target_update_interval
         self._compute_critic_grad = (
-            CudaGraphWrapper(self.compute_critic_grad)  # type: ignore
+            CudaGraphWrapper(self.compute_critic_grad)
             if compile_graph
             else self.compute_critic_grad
         )
         self._compute_actor_grad = (
-            CudaGraphWrapper(self.compute_actor_grad)  # type: ignore
+            CudaGraphWrapper(self.compute_actor_grad)
             if compile_graph
             else self.compute_actor_grad
         )
diff --git a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py
index 8fb2a7dd..2b03c0b2 100644
--- a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py
+++ b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py
@@ -47,7 +47,7 @@ def __init__(
     ):
         super().__init__(observation_shape, action_size, modules, device)
         self._compute_grad = (
-            CudaGraphWrapper(self.compute_grad)  # type: ignore
+            CudaGraphWrapper(self.compute_grad)
             if compile_graph
             else self.compute_grad
         )
@@ -122,7 +122,7 @@ def __init__(
         self._final_tokens = final_tokens
         self._initial_learning_rate = initial_learning_rate
         self._compute_grad = (
-            CudaGraphWrapper(self.compute_grad)  # type: ignore
+            CudaGraphWrapper(self.compute_grad)
             if compile_graph
             else self.compute_grad
         )
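
Every hunk above removes a `# type: ignore` from the same construction: the gradient-step method is wrapped in `CudaGraphWrapper` only when `compile_graph` is enabled, and is otherwise kept as the plain bound method. The following is a minimal, self-contained sketch of that pattern; `IllustrativeCudaGraphWrapper` and `IllustrativeImpl` are stand-ins for illustration, not d3rlpy's actual classes.

    from typing import Any, Callable


    class IllustrativeCudaGraphWrapper:
        """Illustrative stand-in for d3rlpy's CudaGraphWrapper.

        The real wrapper records the wrapped gradient step into a CUDA graph;
        this stub only forwards the call so the wiring can run anywhere.
        """

        def __init__(self, func: Callable[..., Any]) -> None:
            self._func = func

        def __call__(self, *args: Any, **kwargs: Any) -> Any:
            return self._func(*args, **kwargs)


    class IllustrativeImpl:
        """Shows the conditional wrapping shared by the hunks above."""

        def __init__(self, compile_graph: bool) -> None:
            # Wrap the gradient step only when CUDA-graph compilation is
            # requested; otherwise keep the plain bound method.
            self._compute_grad = (
                IllustrativeCudaGraphWrapper(self.compute_grad)
                if compile_graph
                else self.compute_grad
            )

        def compute_grad(self, batch: Any) -> Any:
            # Loss computation and the backward pass live here in the
            # real implementations.
            return batch

Because both branches produce a callable with the same call signature, later code can invoke `self._compute_grad(batch)` without caring whether CUDA-graph capture is active.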