diff --git a/d3rlpy/algos/qlearning/torch/rebrac_impl.py b/d3rlpy/algos/qlearning/torch/rebrac_impl.py
index bb4acef7..ae8262ac 100644
--- a/d3rlpy/algos/qlearning/torch/rebrac_impl.py
+++ b/d3rlpy/algos/qlearning/torch/rebrac_impl.py
@@ -56,10 +56,12 @@ def compute_actor_loss(
             reduction="min",
         )
         lam = 1 / (q_t.abs().mean()).detach()
-        bc_loss = ((batch.actions - action.squashed_mu) ** 2).mean()
+        bc_loss = ((batch.actions - action.squashed_mu) ** 2).sum(
+            dim=1, keepdim=True
+        )
         return TD3PlusBCActorLoss(
-            actor_loss=lam * -q_t.mean() + self._actor_beta * bc_loss,
-            bc_loss=bc_loss,
+            actor_loss=(lam * -q_t + self._actor_beta * bc_loss).mean(),
+            bc_loss=bc_loss.mean(),
         )

     def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
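
A minimal standalone sketch (not part of the patch) of the behavioral difference, assuming a batch of N transitions with D-dimensional actions; the shapes and the beta value here are illustrative, not taken from the repository:

    import torch

    N, D = 4, 3
    actions = torch.randn(N, D)   # stands in for batch.actions
    mu = torch.randn(N, D)        # stands in for action.squashed_mu
    q_t = torch.randn(N, 1)       # min of the twin Q-values, one per sample
    actor_beta = 0.01             # example value for self._actor_beta

    lam = 1 / (q_t.abs().mean()).detach()

    # before: BC term is a single scalar, averaged over both batch and action dims
    bc_loss_old = ((actions - mu) ** 2).mean()
    actor_loss_old = lam * -q_t.mean() + actor_beta * bc_loss_old

    # after: per-sample squared L2 norm over action dims, shape (N, 1),
    # combined with the per-sample Q term before taking the batch mean
    bc_loss_new = ((actions - mu) ** 2).sum(dim=1, keepdim=True)
    actor_loss_new = (lam * -q_t + actor_beta * bc_loss_new).mean()

With the new form, the BC penalty scales with the action dimension and is weighted against each sample's Q-value individually, and only the reported bc_loss is reduced to a scalar via .mean().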