diff --git a/d3rlpy/algos/qlearning/bc.py b/d3rlpy/algos/qlearning/bc.py
index adb886ee..da9cc275 100644
--- a/d3rlpy/algos/qlearning/bc.py
+++ b/d3rlpy/algos/qlearning/bc.py
@@ -49,6 +49,7 @@ class BCConfig(LearnableConfig):
         observation_scaler (d3rlpy.preprocessing.ObservationScaler):
             Observation preprocessor.
         action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """

     batch_size: int = 100
@@ -95,7 +96,7 @@ def inner_create_impl(
         optim = self._config.optim_factory.create(
             imitator.named_modules(),
             lr=self._config.learning_rate,
-            compiled=False,
+            compiled=self.compiled,
         )

         modules = BCModules(optim=optim, imitator=imitator)
@@ -105,6 +106,7 @@ def inner_create_impl(
             action_size=action_size,
             modules=modules,
             policy_type=self._config.policy_type,
+            compiled=self.compiled,
             device=self._device,
         )

@@ -139,6 +141,7 @@ class DiscreteBCConfig(LearnableConfig):
         beta (float): Regularization factor.
         observation_scaler (d3rlpy.preprocessing.ObservationScaler):
             Observation preprocessor.
+        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """

     batch_size: int = 100
@@ -172,7 +175,7 @@ def inner_create_impl(
         optim = self._config.optim_factory.create(
             imitator.named_modules(),
             lr=self._config.learning_rate,
-            compiled=False,
+            compiled=self.compiled,
         )

         modules = DiscreteBCModules(optim=optim, imitator=imitator)
@@ -182,6 +185,7 @@ def inner_create_impl(
             action_size=action_size,
             modules=modules,
             beta=self._config.beta,
+            compiled=self.compiled,
             device=self._device,
         )

diff --git a/d3rlpy/algos/qlearning/torch/bc_impl.py b/d3rlpy/algos/qlearning/torch/bc_impl.py
index 8680ef5b..85b1d5f6 100644
--- a/d3rlpy/algos/qlearning/torch/bc_impl.py
+++ b/d3rlpy/algos/qlearning/torch/bc_impl.py
@@ -1,6 +1,6 @@
 import dataclasses
 from abc import ABCMeta, abstractmethod
-from typing import Dict, Union
+from typing import Callable, Dict, Union

 import torch
 from torch.optim import Optimizer
@@ -18,7 +18,7 @@
     compute_stochastic_imitation_loss,
 )
 from ....optimizers import OptimizerWrapper
-from ....torch_utility import Modules, TorchMiniBatch
+from ....torch_utility import CudaGraphWrapper, Modules, TorchMiniBatch
 from ....types import Shape, TorchObservation
 from ..base import QLearningAlgoImplBase

@@ -32,12 +32,14 @@ class BCBaseModules(Modules):

 class BCBaseImpl(QLearningAlgoImplBase, metaclass=ABCMeta):
     _modules: BCBaseModules
+    _compute_imitator_grad: Callable[[TorchMiniBatch], ImitationLoss]

     def __init__(
         self,
         observation_shape: Shape,
         action_size: int,
         modules: BCBaseModules,
+        compiled: bool,
         device: str,
     ):
         super().__init__(
@@ -46,15 +48,21 @@ def __init__(
             modules=modules,
             device=device,
         )
+        self._compute_imitator_grad = (
+            CudaGraphWrapper(self.compute_imitator_grad)
+            if compiled
+            else self.compute_imitator_grad
+        )

-    def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
+    def compute_imitator_grad(self, batch: TorchMiniBatch) -> ImitationLoss:
         self._modules.optim.zero_grad()
-
-        loss = self.compute_loss(batch.observations, batch.actions)
-
-        loss.loss.backward()
-        self._modules.optim.step()
+        loss = self.compute_loss(batch.observations, batch.actions)
+        loss.loss.backward()
+        return loss

+    def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
+        loss = self._compute_imitator_grad(batch)
+        self._modules.optim.step()
         return asdict_as_float(loss)

     @abstractmethod
@@ -92,12 +100,14 @@ def __init__(
         action_size: int,
         modules: BCModules,
         policy_type: str,
+        compiled: bool,
         device: str,
     ):
         super().__init__(
             observation_shape=observation_shape,
             action_size=action_size,
             modules=modules,
+            compiled=compiled,
             device=device,
         )
         self._policy_type = policy_type
@@ -145,12 +155,14 @@ def __init__(
         action_size: int,
         modules: DiscreteBCModules,
         beta: float,
+        compiled: bool,
         device: str,
     ):
         super().__init__(
             observation_shape=observation_shape,
             action_size=action_size,
             modules=modules,
+            compiled=compiled,
             device=device,
         )
         self._beta = beta
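Note on the split in bc_impl.py: `update_imitator` is divided into `compute_imitator_grad` (zero_grad, forward, backward) followed by a separate `optim.step()` so that the gradient computation can be captured by `CudaGraphWrapper` while the optimizer step runs outside the captured region, matching the `_compute_*_grad` pattern used by the other Q-learning implementations in d3rlpy.

For reference, a minimal usage sketch, not part of the patch: it assumes a d3rlpy build that includes this change plus a CUDA device, and the dataset choice is illustrative. `compile_graph` itself is defined on the shared `LearnableConfig`, which is why the diff only documents it here and threads `self.compiled` into the impls.

    import d3rlpy

    # Toy discrete-action dataset bundled with d3rlpy.
    dataset, _ = d3rlpy.datasets.get_cartpole()

    # compile_graph=True is surfaced to the algorithm as self.compiled,
    # which the patched inner_create_impl forwards both to the optimizer
    # factory and to DiscreteBCImpl.
    bc = d3rlpy.algos.DiscreteBCConfig(
        beta=0.5,
        compile_graph=True,
    ).create(device="cuda:0")

    # Each update now invokes the CudaGraphWrapper-wrapped grad function,
    # then steps the optimizer outside the captured graph.
    bc.fit(dataset, n_steps=10000, n_steps_per_epoch=1000)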