Respect config when figuring out WD decoupling
WD = weight decay; "WD decoupling" refers to whether AdamW (decoupled weight decay) is selected.

Previously we just ignored `adam_w_mode`.

Based on @ofivite's suggestion.
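
For readers unfamiliar with the term, the behaviour that `adam_w_mode` toggles can be sketched on a single scalar parameter. This is illustrative pure Python, not the fused kernel used by the optimizers named in this commit: `adam_step` and all of its arguments are hypothetical names, and the mapping of `adam_w_mode=True` to decoupled decay follows the Apex `FusedAdam` documentation.

```python
import math


def adam_step(p, g, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
              weight_decay=0.0, decoupled_wd=True):
    """One Adam update on a scalar parameter (hypothetical helper).

    decoupled_wd=True  -> AdamW: decay is applied directly to the parameter.
    decoupled_wd=False -> L2 regularization: decay is folded into the gradient.
    """
    b1, b2 = betas
    if not decoupled_wd:
        # Coupled: the decay term passes through the moment estimates.
        g = g + weight_decay * p
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)  # bias-corrected first moment
    v_hat = v / (1 - b2 ** t)  # bias-corrected second moment
    update = m_hat / (math.sqrt(v_hat) + eps)
    if decoupled_wd:
        # Decoupled: the decay term bypasses the adaptive scaling.
        update = update + weight_decay * p
    p = p - lr * update
    return p, m, v


# Zero gradient, so only the weight decay term moves the parameter.
p_decoupled, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, lr=0.1,
                              weight_decay=0.1, decoupled_wd=True)
p_coupled, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, lr=0.1,
                            weight_decay=0.1, decoupled_wd=False)
print(p_decoupled, p_coupled)
```

With these inputs the decoupled step shrinks the parameter by `lr * weight_decay * p`, while the coupled step is rescaled by Adam's normalization and moves the parameter by roughly `lr`. The two modes are therefore not interchangeable for a fixed `weight_decay` value, which is why ignoring `adam_w_mode` matters.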
janEbert committed Aug 21, 2024
1 parent ed73196 commit 424f2fd
Showing 1 changed file with 6 additions and 0 deletions.
@@ -530,11 +530,17 @@ def setup_optimizer_param_groups(self):
        lr=self.cfg.optim.lr,
        weight_decay=self.cfg.optim.get('weight_decay', 0.0),
    )

    decoupled_wd = None
    if is_adam_opt(optim_name):
        optim_kwargs['eps'] = self.cfg.optim.get('eps', 1e-8)
        if optim_name in ['fused_adam', 'distributed_fused_adam', 'megatron_fused_adam']:
            decoupled_wd = self.cfg.optim.get('adam_w_mode', None)

    self._optimizer_param_groups = process_mup_param_groups(
        optim_name,
        self._optimizer_param_groups,
        decoupled_wd=decoupled_wd,
        **optim_kwargs,
    )
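
The selection logic added in this hunk can be sketched in isolation. This is a minimal sketch under stated assumptions, not the repository's implementation. `resolve_decoupled_wd` is a hypothetical helper, a plain dict stands in for `self.cfg.optim`, and the stand-in `is_adam_opt` only guesses at how the real helper behaves. A result of `None` is assumed to mean that the config made no choice, so the decision is left to `process_mup_param_groups`.

```python
from typing import Any, Mapping, Optional

# Optimizer names that the diff treats as exposing an `adam_w_mode` switch.
FUSED_ADAM_NAMES = ('fused_adam', 'distributed_fused_adam', 'megatron_fused_adam')


def is_adam_opt(optim_name: str) -> bool:
    """Stand-in for the real helper (assumption: it matches on the name)."""
    return 'adam' in optim_name.lower()


def resolve_decoupled_wd(optim_name: str, optim_cfg: Mapping[str, Any]) -> Optional[bool]:
    """Hypothetical helper mirroring the diff.

    Returns the configured `adam_w_mode` for fused Adam variants,
    and None (undecided) for every other optimizer.
    """
    decoupled_wd = None
    if is_adam_opt(optim_name) and optim_name in FUSED_ADAM_NAMES:
        decoupled_wd = optim_cfg.get('adam_w_mode', None)
    return decoupled_wd


print(resolve_decoupled_wd('fused_adam', {'adam_w_mode': False}))  # False
print(resolve_decoupled_wd('fused_adam', {}))                      # None
print(resolve_decoupled_wd('sgd', {'adam_w_mode': True}))          # None
```

Keeping `None` distinct from `False` lets an explicit `adam_w_mode: false` in the config override whatever default the callee would otherwise pick, which matches the intent stated in the commit message.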
