Skip to content

Commit

Permalink
Support BC
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Nov 3, 2024
1 parent b41be5a commit c6e6cd4
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
8 changes: 6 additions & 2 deletions d3rlpy/algos/qlearning/bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class BCConfig(LearnableConfig):
observation_scaler (d3rlpy.preprocessing.ObservationScaler):
Observation preprocessor.
action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

batch_size: int = 100
Expand Down Expand Up @@ -95,7 +96,7 @@ def inner_create_impl(
optim = self._config.optim_factory.create(
imitator.named_modules(),
lr=self._config.learning_rate,
compiled=False,
compiled=self.compiled,
)

modules = BCModules(optim=optim, imitator=imitator)
Expand All @@ -105,6 +106,7 @@ def inner_create_impl(
action_size=action_size,
modules=modules,
policy_type=self._config.policy_type,
compiled=self.compiled,
device=self._device,
)

Expand Down Expand Up @@ -139,6 +141,7 @@ class DiscreteBCConfig(LearnableConfig):
beta (float): Regularization factor.
observation_scaler (d3rlpy.preprocessing.ObservationScaler):
Observation preprocessor.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

batch_size: int = 100
Expand Down Expand Up @@ -172,7 +175,7 @@ def inner_create_impl(
optim = self._config.optim_factory.create(
imitator.named_modules(),
lr=self._config.learning_rate,
compiled=False,
compiled=self.compiled,
)

modules = DiscreteBCModules(optim=optim, imitator=imitator)
Expand All @@ -182,6 +185,7 @@ def inner_create_impl(
action_size=action_size,
modules=modules,
beta=self._config.beta,
compiled=self.compiled,
device=self._device,
)

Expand Down
24 changes: 18 additions & 6 deletions d3rlpy/algos/qlearning/torch/bc_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
from abc import ABCMeta, abstractmethod
from typing import Dict, Union
from typing import Callable, Dict, Union

import torch
from torch.optim import Optimizer
Expand All @@ -18,7 +18,7 @@
compute_stochastic_imitation_loss,
)
from ....optimizers import OptimizerWrapper
from ....torch_utility import Modules, TorchMiniBatch
from ....torch_utility import CudaGraphWrapper, Modules, TorchMiniBatch
from ....types import Shape, TorchObservation
from ..base import QLearningAlgoImplBase

Expand All @@ -32,12 +32,14 @@ class BCBaseModules(Modules):

class BCBaseImpl(QLearningAlgoImplBase, metaclass=ABCMeta):
_modules: BCBaseModules
_compute_imitator_grad: Callable[[TorchMiniBatch], ImitationLoss]

def __init__(
self,
observation_shape: Shape,
action_size: int,
modules: BCBaseModules,
compiled: bool,
device: str,
):
super().__init__(
Expand All @@ -46,15 +48,21 @@ def __init__(
modules=modules,
device=device,
)
self._compute_imitator_grad = (
CudaGraphWrapper(self.compute_imitator_grad)
if compiled
else self.compute_imitator_grad
)

def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
def compute_imitator_grad(self, batch: TorchMiniBatch) -> ImitationLoss:
self._modules.optim.zero_grad()

loss = self.compute_loss(batch.observations, batch.actions)

loss.loss.backward()
self._modules.optim.step()
return loss

def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
loss = self._compute_imitator_grad(batch)
self._modules.optim.step()
return asdict_as_float(loss)

@abstractmethod
Expand Down Expand Up @@ -92,12 +100,14 @@ def __init__(
action_size: int,
modules: BCModules,
policy_type: str,
compiled: bool,
device: str,
):
super().__init__(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
compiled=compiled,
device=device,
)
self._policy_type = policy_type
Expand Down Expand Up @@ -145,12 +155,14 @@ def __init__(
action_size: int,
modules: DiscreteBCModules,
beta: float,
compiled: bool,
device: str,
):
super().__init__(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
compiled=compiled,
device=device,
)
self._beta = beta
Expand Down

0 comments on commit c6e6cd4

Please sign in to comment.