Also scale Adam epsilon
janEbert committed Aug 13, 2024
1 parent b919fd2 commit a873636
Showing 2 changed files with 23 additions and 10 deletions.
@@ -49,7 +49,7 @@
 from nemo.collections.nlp.modules.common.megatron.build_model import build_model
 from nemo.collections.nlp.modules.common.megatron.module import Float16Module
 from nemo.collections.nlp.modules.common.megatron.mup.convert import maybe_mup_init
-from nemo.collections.nlp.modules.common.megatron.mup.optim import process_mup_param_groups
+from nemo.collections.nlp.modules.common.megatron.mup.optim import is_adam_opt, process_mup_param_groups
 from nemo.collections.nlp.modules.common.megatron.utils import (
     ApexGuardDefaults,
     average_losses_across_data_parallel_group,
@@ -526,11 +526,16 @@ def setup_optimizer_param_groups(self):
         if self.cfg.get('make_mup', False) and hasattr(self.cfg, 'shape_file'):
             # muP parameter group processing
             optim_name = self.cfg.optim.get('name', 'fused_adam')
+            optim_kwargs = dict(
+                lr=self.cfg.optim.lr,
+                weight_decay=self.cfg.optim.get('weight_decay', 0.0),
+            )
+            if is_adam_opt(optim_name):
+                optim_kwargs['eps'] = self.cfg.optim.get('eps', 1e-8)
             self._optimizer_param_groups = process_mup_param_groups(
                 optim_name,
                 self._optimizer_param_groups,
-                lr=self.cfg.optim.lr,
-                weight_decay=self.cfg.optim.get('weight_decay', 0.0),
+                **optim_kwargs,
             )

     def setup_mcore_distributed_parallel(self):
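Only the Adam-family optimizers take an eps argument, so the model now builds its optimizer kwargs conditionally and forwards eps only when is_adam_opt matches; a non-Adam run (e.g. SGD) passes the same kwargs as before. A minimal sketch of the resulting kwargs, using a hypothetical build_optim_kwargs helper and made-up config values rather than NeMo's actual config:

def build_optim_kwargs(optim_name, lr, weight_decay, eps=1e-8):
    # Mirrors the hunk above: eps is forwarded only for Adam-family optimizers.
    kwargs = dict(lr=lr, weight_decay=weight_decay)
    if optim_name in ('adam', 'adamw', 'fused_adam', 'distributed_fused_adam', 'megatron_fused_adam'):
        kwargs['eps'] = eps
    return kwargs

print(build_optim_kwargs('fused_adam', 1e-3, 0.1))  # {'lr': 0.001, 'weight_decay': 0.1, 'eps': 1e-08}
print(build_optim_kwargs('sgd', 1e-3, 0.1))         # {'lr': 0.001, 'weight_decay': 0.1}
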
nemo/collections/nlp/modules/common/megatron/mup/optim.py (22 changes: 15 additions & 7 deletions)
@@ -63,6 +63,16 @@ def MuOptimizer(params, **kwargs):
 from torch.optim import SGD, Adam, AdamW


+def is_adam_opt(optim_name):
+    return optim_name in [
+        'adam',
+        'adamw',
+        'fused_adam',
+        'distributed_fused_adam',
+        'megatron_fused_adam',
+    ]
+
+
 def process_param_groups(params, **kwargs):
     param_groups = list(params)
     if not isinstance(param_groups[0], dict):
@@ -72,17 +82,13 @@ def process_param_groups(params, **kwargs):
             param_group['lr'] = kwargs['lr']
         if 'weight_decay' not in param_group:
             param_group['weight_decay'] = kwargs.get('weight_decay', 0.0)
+        if 'eps' not in param_group and 'eps' in kwargs:
+            param_group['eps'] = kwargs['eps']
     return param_groups


 def process_mup_param_groups(optim_name, params, decoupled_wd=None, **kwargs):
-    if optim_name in [
-        'adam',
-        'adamw',
-        'fused_adam',
-        'distributed_fused_adam',
-        'megatron_fused_adam',
-    ]:
+    if is_adam_opt(optim_name):
         if decoupled_wd is None:
             decoupled_wd = optim_name != 'adam'
         param_groups = process_adam_param_groups(params, decoupled_wd=decoupled_wd, **kwargs)
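With the eps handling added to process_param_groups, a global eps passed through the kwargs is stamped onto every parameter group that does not already carry one, while explicit per-group values are left alone. A rough usage sketch, assuming the functions from this file are importable and using empty dummy parameter lists:

groups = process_param_groups(
    [{'params': []}, {'params': [], 'eps': 1e-12}],
    lr=1e-3,
    eps=1e-8,
)
assert groups[0]['eps'] == 1e-8          # filled in from the global kwarg
assert groups[1]['eps'] == 1e-12         # explicit per-group value wins
assert groups[0]['weight_decay'] == 0.0  # default when not supplied
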
@@ -124,6 +130,8 @@ def new_group():
             group['lr'] /= width_mult
             if not decoupled_wd:
                 group['weight_decay'] *= width_mult
+            if 'eps' in group:
+                group['eps'] /= width_mult
         new_param_groups.extend(list(matrix_like_p.values()) + [vector_like_p])
     return new_param_groups

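This last hunk is what gives the commit its name. For matrix-like (hidden) weights, muP already divides the Adam learning rate by the group's width multiplier; eps now gets the same treatment, the idea being (roughly) that eps sits next to the second-moment term in Adam's update denominator, and since that term tends to shrink for these weights as width grows, a fixed eps would become relatively larger and quietly dampen their updates. A standalone sketch of the per-group scaling, with a hypothetical helper name and illustrative numbers rather than NeMo's actual API:

def scale_adam_group_for_mup(group, width_mult, decoupled_wd=True):
    # Same arithmetic as the hunk above, applied to a copy of one param group.
    group = dict(group)
    group['lr'] /= width_mult
    if not decoupled_wd:
        group['weight_decay'] *= width_mult
    if 'eps' in group:
        group['eps'] /= width_mult
    return group

base = {'lr': 1e-3, 'weight_decay': 0.1, 'eps': 1e-8}
print(scale_adam_group_for_mup(base, width_mult=8.0))
# {'lr': 0.000125, 'weight_decay': 0.1, 'eps': 1.25e-09}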
