Skip to content

Commit

Permalink
Fix details of ReBRAC reproduction
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Jun 9, 2024
1 parent 7e72150 commit 1f8942b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
12 changes: 8 additions & 4 deletions d3rlpy/algos/qlearning/torch/rebrac_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ def compute_actor_loss(
self, batch: TorchMiniBatch, action: ActionOutput
) -> TD3PlusBCActorLoss:
q_t = self._q_func_forwarder.compute_expected_q(
batch.observations, action.squashed_mu, "none"
)[0]
batch.observations,
action.squashed_mu,
reduction="min",
)
lam = 1 / (q_t.abs().mean()).detach()
bc_loss = ((batch.actions - action.squashed_mu) ** 2).mean()
return TD3PlusBCActorLoss(
Expand All @@ -78,6 +80,8 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
)

# BRAC regularization
bc_loss = (clipped_action - batch.next_actions) ** 2
bc_penalty = ((clipped_action - batch.next_actions) ** 2).sum(
dim=1, keepdim=True
)

return next_q - self._critic_beta * bc_loss.sum(dim=1, keepdim=True)
return next_q - self._critic_beta * bc_penalty
2 changes: 1 addition & 1 deletion reproductions/offline/rebrac.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def main() -> None:

rebrac.fit(
dataset,
n_steps=500000,
n_steps=1000000,
n_steps_per_epoch=1000,
save_interval=10,
evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
Expand Down

0 comments on commit 1f8942b

Please sign in to comment.