From 9a77f408123aec0e1a0254fd36fe8c7c6009c900 Mon Sep 17 00:00:00 2001 From: takuseno Date: Mon, 4 Nov 2024 23:49:17 +0900 Subject: [PATCH] Refactor DQN variants --- d3rlpy/algos/qlearning/bcq.py | 44 ++++++-- d3rlpy/algos/qlearning/cql.py | 44 ++++++-- d3rlpy/algos/qlearning/dqn.py | 70 +++++++++--- d3rlpy/algos/qlearning/functional.py | 92 ++++++++++++++++ d3rlpy/algos/qlearning/nfq.py | 38 +++++-- d3rlpy/algos/qlearning/torch/bcq_impl.py | 45 ++++---- d3rlpy/algos/qlearning/torch/cql_impl.py | 33 ++---- d3rlpy/algos/qlearning/torch/dqn_impl.py | 133 ++++++++++------------- 8 files changed, 335 insertions(+), 164 deletions(-) create mode 100644 d3rlpy/algos/qlearning/functional.py diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py index f89db8d6..fba40c0a 100644 --- a/d3rlpy/algos/qlearning/bcq.py +++ b/d3rlpy/algos/qlearning/bcq.py @@ -16,12 +16,15 @@ from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field from ...types import Shape from .base import QLearningAlgoBase +from .functional import FunctionalQLearningAlgoImplBase from .torch.bcq_impl import ( BCQImpl, BCQModules, - DiscreteBCQImpl, + DiscreteBCQActionSampler, + DiscreteBCQLossFn, DiscreteBCQModules, ) +from .torch.dqn_impl import DQNUpdater, DQNValuePredictor __all__ = ["BCQConfig", "BCQ", "DiscreteBCQConfig", "DiscreteBCQ"] @@ -363,7 +366,7 @@ def get_type() -> str: return "discrete_bcq" -class DiscreteBCQ(QLearningAlgoBase[DiscreteBCQImpl, DiscreteBCQConfig]): +class DiscreteBCQ(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, DiscreteBCQConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: @@ -422,17 +425,38 @@ def inner_create_impl( optim=optim, ) - self._impl = DiscreteBCQImpl( - observation_shape=observation_shape, - action_size=action_size, + # build functional components + updater = DQNUpdater( modules=modules, - q_func_forwarder=q_func_forwarder, - targ_q_func_forwarder=targ_q_func_forwarder, + dqn_loss_fn=DiscreteBCQLossFn( + modules=modules, + q_func_forwarder=q_func_forwarder, + targ_q_func_forwarder=targ_q_func_forwarder, + gamma=self._config.gamma, + beta=self._config.beta, + ), target_update_interval=self._config.target_update_interval, - gamma=self._config.gamma, - action_flexibility=self._config.action_flexibility, - beta=self._config.beta, compiled=self.compiled, + ) + action_sampler = DiscreteBCQActionSampler( + modules=modules, + q_func_forwarder=q_func_forwarder, + action_flexibility=self._config.action_flexibility, + ) + value_predictor = DQNValuePredictor(q_func_forwarder) + + self._impl = FunctionalQLearningAlgoImplBase( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + updater=updater, + exploit_action_sampler=action_sampler, + explore_action_sampler=action_sampler, + value_predictor=value_predictor, + q_function=q_funcs, + q_function_optim=optim.optim, + policy=None, + policy_optim=None, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/cql.py b/d3rlpy/algos/qlearning/cql.py index de4fb79c..53fceacb 100644 --- a/d3rlpy/algos/qlearning/cql.py +++ b/d3rlpy/algos/qlearning/cql.py @@ -14,8 +14,14 @@ from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field from ...types import Shape from .base import QLearningAlgoBase -from .torch.cql_impl import CQLImpl, CQLModules, DiscreteCQLImpl -from .torch.dqn_impl import DQNModules +from .functional import FunctionalQLearningAlgoImplBase +from .torch.cql_impl import CQLImpl, CQLModules, DiscreteCQLLossFn +from 
.torch.dqn_impl import ( + DQNActionSampler, + DQNModules, + DQNUpdater, + DQNValuePredictor, +) __all__ = ["CQLConfig", "CQL", "DiscreteCQLConfig", "DiscreteCQL"] @@ -304,7 +310,7 @@ def get_type() -> str: return "discrete_cql" -class DiscreteCQL(QLearningAlgoBase[DiscreteCQLImpl, DiscreteCQLConfig]): +class DiscreteCQL(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, DiscreteCQLConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: @@ -339,16 +345,34 @@ def inner_create_impl( optim=optim, ) - self._impl = DiscreteCQLImpl( - observation_shape=observation_shape, - action_size=action_size, + # build functional components + updater = DQNUpdater( modules=modules, - q_func_forwarder=q_func_forwarder, - targ_q_func_forwarder=targ_q_func_forwarder, + dqn_loss_fn=DiscreteCQLLossFn( + action_size=action_size, + q_func_forwarder=q_func_forwarder, + targ_q_func_forwarder=targ_q_func_forwarder, + gamma=self._config.gamma, + alpha=self._config.alpha, + ), target_update_interval=self._config.target_update_interval, - gamma=self._config.gamma, - alpha=self._config.alpha, compiled=self.compiled, + ) + action_sampler = DQNActionSampler(q_func_forwarder) + value_predictor = DQNValuePredictor(q_func_forwarder) + + self._impl = FunctionalQLearningAlgoImplBase( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + updater=updater, + exploit_action_sampler=action_sampler, + explore_action_sampler=action_sampler, + value_predictor=value_predictor, + q_function=q_funcs, + q_function_optim=optim.optim, + policy=None, + policy_optim=None, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/dqn.py b/d3rlpy/algos/qlearning/dqn.py index 4e993b0c..f2bc3a38 100644 --- a/d3rlpy/algos/qlearning/dqn.py +++ b/d3rlpy/algos/qlearning/dqn.py @@ -8,7 +8,15 @@ from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field from ...types import Shape from .base import QLearningAlgoBase -from .torch.dqn_impl import DoubleDQNImpl, DQNImpl, DQNModules +from .functional import FunctionalQLearningAlgoImplBase +from .torch.dqn_impl import ( + DoubleDQNLossFn, + DQNActionSampler, + DQNLossFn, + DQNModules, + DQNUpdater, + DQNValuePredictor, +) __all__ = ["DQNConfig", "DQN", "DoubleDQNConfig", "DoubleDQN"] @@ -66,7 +74,7 @@ def get_type() -> str: return "dqn" -class DQN(QLearningAlgoBase[DQNImpl, DQNConfig]): +class DQN(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, DQNConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: @@ -101,15 +109,32 @@ def inner_create_impl( optim=optim, ) - self._impl = DQNImpl( + # build functional components + updater = DQNUpdater( + modules=modules, + dqn_loss_fn=DQNLossFn( + q_func_forwarder=forwarder, + targ_q_func_forwarder=targ_forwarder, + gamma=self._config.gamma, + ), + target_update_interval=self._config.target_update_interval, + compiled=self.compiled, + ) + action_sampler = DQNActionSampler(forwarder) + value_predictor = DQNValuePredictor(forwarder) + + self._impl = FunctionalQLearningAlgoImplBase( observation_shape=observation_shape, action_size=action_size, - q_func_forwarder=forwarder, - targ_q_func_forwarder=targ_forwarder, - target_update_interval=self._config.target_update_interval, modules=modules, - gamma=self._config.gamma, - compiled=self.compiled, + updater=updater, + exploit_action_sampler=action_sampler, + explore_action_sampler=action_sampler, + value_predictor=value_predictor, + q_function=q_funcs, + q_function_optim=optim.optim, + policy=None, + 
policy_optim=None, device=self._device, ) @@ -212,15 +237,32 @@ def inner_create_impl( optim=optim, ) - self._impl = DoubleDQNImpl( - observation_shape=observation_shape, - action_size=action_size, + # build functional components + updater = DQNUpdater( modules=modules, - q_func_forwarder=forwarder, - targ_q_func_forwarder=targ_forwarder, + dqn_loss_fn=DoubleDQNLossFn( + q_func_forwarder=forwarder, + targ_q_func_forwarder=targ_forwarder, + gamma=self._config.gamma, + ), target_update_interval=self._config.target_update_interval, - gamma=self._config.gamma, compiled=self.compiled, + ) + action_sampler = DQNActionSampler(forwarder) + value_predictor = DQNValuePredictor(forwarder) + + self._impl = FunctionalQLearningAlgoImplBase( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + updater=updater, + exploit_action_sampler=action_sampler, + explore_action_sampler=action_sampler, + value_predictor=value_predictor, + q_function=q_funcs, + q_function_optim=optim.optim, + policy=None, + policy_optim=None, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/functional.py b/d3rlpy/algos/qlearning/functional.py new file mode 100644 index 00000000..c818f136 --- /dev/null +++ b/d3rlpy/algos/qlearning/functional.py @@ -0,0 +1,92 @@ +from typing import Optional, Protocol + +import torch +from torch import nn + +from ...models.torch.policies import Policy +from ...torch_utility import Modules, TorchMiniBatch +from ...types import Shape, TorchObservation +from .base import QLearningAlgoImplBase + +__all__ = ["Updater", "ActionSampler", "ValuePredictor", "FunctionalQLearningAlgoImplBase"] + + +class Updater(Protocol): + def __call__(self, batch: TorchMiniBatch, grad_step: int) -> dict[str, float]: + ... + + +class ActionSampler(Protocol): + def __call__(self, x: TorchObservation) -> torch.Tensor: + ... + + +class ValuePredictor(Protocol): + def __call__(self, x: TorchObservation, action: torch.Tensor) -> torch.Tensor: + ... 
+ + +class FunctionalQLearningAlgoImplBase(QLearningAlgoImplBase): + def __init__( + self, + observation_shape: Shape, + action_size: int, + modules: Modules, + updater: Updater, + exploit_action_sampler: ActionSampler, + explore_action_sampler: ActionSampler, + value_predictor: ValuePredictor, + q_function: nn.ModuleList, + q_function_optim: torch.optim.Optimizer, + policy: Optional[Policy], + policy_optim: Optional[torch.optim.Optimizer], + device: str, + ): + super().__init__( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + device=device, + ) + self._updater = updater + self._exploit_action_sampler = exploit_action_sampler + self._explore_action_sampler = explore_action_sampler + self._value_predictor = value_predictor + self._q_function = q_function + self._q_function_optim = q_function_optim + self._policy = policy + self._policy_optim = policy_optim + + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> dict[str, float]: + return self._updater(batch, grad_step) + + def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor: + return self._exploit_action_sampler(x) + + def inner_sample_action(self, x: TorchObservation) -> torch.Tensor: + return self._explore_action_sampler(x) + + def inner_predict_value( + self, x: TorchObservation, action: torch.Tensor + ) -> torch.Tensor: + return self._value_predictor(x, action) + + @property + def policy(self) -> Policy: + assert self._policy + return self._policy + + @property + def policy_optim(self) -> torch.optim.Optimizer: + assert self._policy_optim + return self._policy_optim + + @property + def q_function(self) -> nn.ModuleList: + return self._q_function + + @property + def q_function_optim(self) -> torch.optim.Optimizer: + return self._q_function_optim diff --git a/d3rlpy/algos/qlearning/nfq.py b/d3rlpy/algos/qlearning/nfq.py index 634f5c27..9dc1d637 100644 --- a/d3rlpy/algos/qlearning/nfq.py +++ b/d3rlpy/algos/qlearning/nfq.py @@ -8,7 +8,14 @@ from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field from ...types import Shape from .base import QLearningAlgoBase -from .torch.dqn_impl import DQNImpl, DQNModules +from .functional import FunctionalQLearningAlgoImplBase +from .torch.dqn_impl import ( + DQNActionSampler, + DQNLossFn, + DQNModules, + DQNUpdater, + DQNValuePredictor, +) __all__ = ["NFQConfig", "NFQ"] @@ -68,7 +75,7 @@ def get_type() -> str: return "nfq" -class NFQ(QLearningAlgoBase[DQNImpl, NFQConfig]): +class NFQ(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, NFQConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: @@ -103,15 +110,32 @@ def inner_create_impl( optim=optim, ) - self._impl = DQNImpl( - observation_shape=observation_shape, - action_size=action_size, - modules=modules, + loss_fn = DQNLossFn( q_func_forwarder=q_func_forwarder, targ_q_func_forwarder=targ_q_func_forwarder, - target_update_interval=1, gamma=self._config.gamma, + ) + updater = DQNUpdater( + modules=modules, + dqn_loss_fn=loss_fn, + target_update_interval=1, compiled=self.compiled, + ) + action_sampler = DQNActionSampler(q_func_forwarder) + value_predictor = DQNValuePredictor(q_func_forwarder) + + self._impl = FunctionalQLearningAlgoImplBase( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + updater=updater, + exploit_action_sampler=action_sampler, + explore_action_sampler=action_sampler, + value_predictor=value_predictor, + q_function=q_funcs, + q_function_optim=optim.optim, + 
policy=None, + policy_optim=None, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py index 6c77a4bb..67823b81 100644 --- a/d3rlpy/algos/qlearning/torch/bcq_impl.py +++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py @@ -27,15 +27,17 @@ soft_sync, ) from ....types import Shape, TorchObservation +from ..functional import ActionSampler from .ddpg_impl import DDPGBaseActorLoss, DDPGBaseImpl, DDPGBaseModules -from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules +from .dqn_impl import DoubleDQNLossFn, DQNLoss, DQNModules __all__ = [ "BCQImpl", - "DiscreteBCQImpl", "BCQModules", "DiscreteBCQModules", "DiscreteBCQLoss", + "DiscreteBCQLossFn", + "DiscreteBCQActionSampler", ] @@ -240,43 +242,25 @@ class DiscreteBCQLoss(DQNLoss): imitator_loss: torch.Tensor -class DiscreteBCQImpl(DoubleDQNImpl): - _modules: DiscreteBCQModules - _action_flexibility: float - _beta: float - +class DiscreteBCQLossFn(DoubleDQNLossFn): def __init__( self, - observation_shape: Shape, - action_size: int, modules: DiscreteBCQModules, q_func_forwarder: DiscreteEnsembleQFunctionForwarder, targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, - target_update_interval: int, gamma: float, - action_flexibility: float, beta: float, - compiled: bool, - device: str, ): super().__init__( - observation_shape=observation_shape, - action_size=action_size, - modules=modules, q_func_forwarder=q_func_forwarder, targ_q_func_forwarder=targ_q_func_forwarder, - target_update_interval=target_update_interval, gamma=gamma, - compiled=compiled, - device=device, ) - self._action_flexibility = action_flexibility + self._modules = modules self._beta = beta - def compute_loss( - self, batch: TorchMiniBatch, q_tpn: torch.Tensor - ) -> DiscreteBCQLoss: - td_loss = super().compute_loss(batch, q_tpn).loss + def __call__(self, batch: TorchMiniBatch) -> DiscreteBCQLoss: + td_loss = super().__call__(batch).loss imitator_loss = compute_discrete_imitation_loss( policy=self._modules.imitator, x=batch.observations, @@ -288,7 +272,18 @@ def compute_loss( loss=loss, td_loss=td_loss, imitator_loss=imitator_loss.loss ) - def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor: +class DiscreteBCQActionSampler(ActionSampler): + def __init__( + self, + modules: DiscreteBCQModules, + q_func_forwarder: DiscreteEnsembleQFunctionForwarder, + action_flexibility: float, + ): + self._modules = modules + self._q_func_forwarder = q_func_forwarder + self._action_flexibility = action_flexibility + + def __call__(self, x: TorchObservation) -> torch.Tensor: dist = self._modules.imitator(x) log_probs = F.log_softmax(dist.logits, dim=1) ratio = log_probs - log_probs.max(dim=1, keepdim=True).values diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index f3eb7955..caf673b3 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -19,11 +19,11 @@ ) from ....types import Shape, TorchObservation from .ddpg_impl import DDPGBaseCriticLoss -from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules +from .dqn_impl import DoubleDQNLossFn, DQNLoss from .sac_impl import SACImpl, SACModules from .utility import sample_q_values_with_policy -__all__ = ["CQLImpl", "DiscreteCQLImpl", "CQLModules", "DiscreteCQLLoss"] +__all__ = ["CQLImpl", "DiscreteCQLLossFn", "CQLModules", "DiscreteCQLLoss"] @dataclasses.dataclass(frozen=True) @@ -234,33 +234,17 @@ class DiscreteCQLLoss(DQNLoss): conservative_loss: torch.Tensor -class 
DiscreteCQLImpl(DoubleDQNImpl): - _alpha: float - +class DiscreteCQLLossFn(DoubleDQNLossFn): def __init__( self, - observation_shape: Shape, action_size: int, - modules: DQNModules, q_func_forwarder: DiscreteEnsembleQFunctionForwarder, targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, - target_update_interval: int, gamma: float, alpha: float, - compiled: bool, - device: str, ): - super().__init__( - observation_shape=observation_shape, - action_size=action_size, - modules=modules, - q_func_forwarder=q_func_forwarder, - targ_q_func_forwarder=targ_q_func_forwarder, - target_update_interval=target_update_interval, - gamma=gamma, - compiled=compiled, - device=device, - ) + super().__init__(q_func_forwarder, targ_q_func_forwarder, gamma) + self._action_size = action_size self._alpha = alpha def _compute_conservative_loss( @@ -271,17 +255,16 @@ def _compute_conservative_loss( logsumexp = torch.logsumexp(values, dim=1, keepdim=True) # estimate action-values under data distribution - one_hot = F.one_hot(act_t.view(-1), num_classes=self.action_size) + one_hot = F.one_hot(act_t.view(-1), num_classes=self._action_size) data_values = (values * one_hot).sum(dim=1, keepdim=True) return (logsumexp - data_values).mean() - def compute_loss( + def __call__( self, batch: TorchMiniBatch, - q_tpn: torch.Tensor, ) -> DiscreteCQLLoss: - td_loss = super().compute_loss(batch, q_tpn).loss + td_loss = super().__call__(batch).loss conservative_loss = self._compute_conservative_loss( batch.observations, batch.actions.long() ) diff --git a/d3rlpy/algos/qlearning/torch/dqn_impl.py b/d3rlpy/algos/qlearning/torch/dqn_impl.py index 859b8b19..8db4a569 100644 --- a/d3rlpy/algos/qlearning/torch/dqn_impl.py +++ b/d3rlpy/algos/qlearning/torch/dqn_impl.py @@ -1,9 +1,7 @@ import dataclasses -from typing import Callable import torch from torch import nn -from torch.optim import Optimizer from ....dataclass_utils import asdict_as_float from ....models.torch import DiscreteEnsembleQFunctionForwarder @@ -14,11 +12,17 @@ TorchMiniBatch, hard_sync, ) -from ....types import Shape, TorchObservation -from ..base import QLearningAlgoImplBase -from .utility import DiscreteQFunctionMixin +from ....types import TorchObservation +from ..functional import ActionSampler, Updater, ValuePredictor -__all__ = ["DQNImpl", "DQNModules", "DQNLoss", "DoubleDQNImpl"] +__all__ = [ + "DQNModules", + "DQNLoss", + "DQNLossFn", + "DoubleDQNLossFn", + "DQNActionSampler", + "DQNValuePredictor", +] @dataclasses.dataclass(frozen=True) @@ -33,64 +37,19 @@ class DQNLoss: loss: torch.Tensor -class DQNImpl(DiscreteQFunctionMixin, QLearningAlgoImplBase): - _modules: DQNModules - _compute_grad: Callable[[TorchMiniBatch], DQNLoss] - _gamma: float - _q_func_forwarder: DiscreteEnsembleQFunctionForwarder - _targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder - _target_update_interval: int - +class DQNLossFn: def __init__( self, - observation_shape: Shape, - action_size: int, - modules: DQNModules, q_func_forwarder: DiscreteEnsembleQFunctionForwarder, targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, - target_update_interval: int, gamma: float, - compiled: bool, - device: str, ): - super().__init__( - observation_shape=observation_shape, - action_size=action_size, - modules=modules, - device=device, - ) - self._gamma = gamma self._q_func_forwarder = q_func_forwarder self._targ_q_func_forwarder = targ_q_func_forwarder - self._target_update_interval = target_update_interval - self._compute_grad = ( - CudaGraphWrapper(self.compute_grad) - if compiled -
else self.compute_grad - ) - hard_sync(modules.targ_q_funcs, modules.q_funcs) + self._gamma = gamma - def compute_grad(self, batch: TorchMiniBatch) -> DQNLoss: - self._modules.optim.zero_grad() + def __call__(self, batch: TorchMiniBatch) -> DQNLoss: q_tpn = self.compute_target(batch) - loss = self.compute_loss(batch, q_tpn) - loss.loss.backward() - return loss - - def inner_update( - self, batch: TorchMiniBatch, grad_step: int - ) -> dict[str, float]: - loss = self._compute_grad(batch) - self._modules.optim.step() - if grad_step % self._target_update_interval == 0: - self.update_target() - return asdict_as_float(loss) - - def compute_loss( - self, - batch: TorchMiniBatch, - q_tpn: torch.Tensor, - ) -> DQNLoss: loss = self._q_func_forwarder.compute_error( observations=batch.observations, actions=batch.actions.long(), @@ -113,30 +72,58 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: reduction="min", ) - def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor: - return self._q_func_forwarder.compute_expected_q(x).argmax(dim=1) - def inner_sample_action(self, x: TorchObservation) -> torch.Tensor: - return self.inner_predict_best_action(x) - - def update_target(self) -> None: - hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs) - - @property - def q_function(self) -> nn.ModuleList: - return self._modules.q_funcs - - @property - def q_function_optim(self) -> Optimizer: - return self._modules.optim.optim - - -class DoubleDQNImpl(DQNImpl): +class DoubleDQNLossFn(DQNLossFn): def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): - action = self.inner_predict_best_action(batch.next_observations) + action = self._q_func_forwarder.compute_expected_q(batch.next_observations).argmax(dim=1) return self._targ_q_func_forwarder.compute_target( batch.next_observations, action, reduction="min", ) + + +class DQNActionSampler(ActionSampler): + def __init__(self, q_func_forwarder: DiscreteEnsembleQFunctionForwarder): + self._q_func_forwarder = q_func_forwarder + + def __call__(self, x: TorchObservation) -> torch.Tensor: + return self._q_func_forwarder.compute_expected_q(x).argmax(dim=1) + + +class DQNValuePredictor(ValuePredictor): + def __init__(self, q_func_forwarder: DiscreteEnsembleQFunctionForwarder): + self._q_func_forwarder = q_func_forwarder + + def __call__(self, x: TorchObservation, action: torch.Tensor) -> torch.Tensor: + values = self._q_func_forwarder.compute_expected_q(x, reduction="mean") + flat_action = action.reshape(-1) + return values[torch.arange(0, values.size(0)), flat_action].reshape(-1) + + +class DQNUpdater(Updater): + def __init__( + self, + modules: DQNModules, + dqn_loss_fn: DQNLossFn, + target_update_interval: int, + compiled: bool, + ): + self._modules = modules + self._dqn_loss_fn = dqn_loss_fn + self._target_update_interval = target_update_interval + self._compute_grad = CudaGraphWrapper(self.compute_grad) if compiled else self.compute_grad + + def compute_grad(self, batch: TorchMiniBatch) -> DQNLoss: + self._modules.optim.zero_grad() + loss = self._dqn_loss_fn(batch) + loss.loss.backward() + return loss + + def __call__(self, batch: TorchMiniBatch, grad_step: int) -> dict[str, float]: + loss = self._compute_grad(batch) + self._modules.optim.step() + if grad_step % self._target_update_interval == 0: + hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs) + return asdict_as_float(loss)
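
As an illustration of the composition this refactor enables, here is a hypothetical sketch (not part of this patch) of an epsilon-greedy sampler: the ActionSampler protocol lets an exploration strategy wrap the greedy DQNActionSampler and be passed as explore_action_sampler to FunctionalQLearningAlgoImplBase without subclassing any impl class. The name EpsilonGreedyActionSampler is assumed for illustration only.

import torch

from d3rlpy.algos.qlearning.functional import ActionSampler
from d3rlpy.types import TorchObservation


class EpsilonGreedyActionSampler(ActionSampler):
    """Hypothetical epsilon-greedy wrapper around a greedy ActionSampler."""

    def __init__(
        self,
        greedy_sampler: ActionSampler,
        action_size: int,
        epsilon: float,
    ):
        self._greedy_sampler = greedy_sampler
        self._action_size = action_size
        self._epsilon = epsilon

    def __call__(self, x: TorchObservation) -> torch.Tensor:
        # greedy actions from the wrapped sampler (e.g. DQNActionSampler)
        greedy_actions = self._greedy_sampler(x)
        # uniform random actions with matching shape, dtype and device
        random_actions = torch.randint_like(greedy_actions, self._action_size)
        # with probability epsilon, replace a greedy action with a random one
        explore = (
            torch.rand(greedy_actions.shape, device=greedy_actions.device)
            < self._epsilon
        )
        return torch.where(explore, random_actions, greedy_actions)

Under this assumption, explore_action_sampler=EpsilonGreedyActionSampler(action_sampler, action_size, 0.1) could replace the shared greedy sampler in the constructors above, while exploit_action_sampler, the Updater, and the ValuePredictor stay unchanged.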