diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 3f1cc648f3d8..5ed09fece3f5 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -49,7 +49,7 @@
 from nemo.collections.nlp.modules.common.megatron.build_model import build_model
 from nemo.collections.nlp.modules.common.megatron.module import Float16Module
 from nemo.collections.nlp.modules.common.megatron.mup.convert import maybe_mup_init
-from nemo.collections.nlp.modules.common.megatron.mup.optim import process_mup_param_groups
+from nemo.collections.nlp.modules.common.megatron.mup.optim import is_adam_opt, process_mup_param_groups
 from nemo.collections.nlp.modules.common.megatron.utils import (
     ApexGuardDefaults,
     average_losses_across_data_parallel_group,
@@ -526,11 +526,16 @@ def setup_optimizer_param_groups(self):
         if self.cfg.get('make_mup', False) and hasattr(self.cfg, 'shape_file'):
             # muP parameter group processing
             optim_name = self.cfg.optim.get('name', 'fused_adam')
+            optim_kwargs = dict(
+                lr=self.cfg.optim.lr,
+                weight_decay=self.cfg.optim.get('weight_decay', 0.0),
+            )
+            if is_adam_opt(optim_name):
+                optim_kwargs['eps'] = self.cfg.optim.get('eps', 1e-8)
             self._optimizer_param_groups = process_mup_param_groups(
                 optim_name,
                 self._optimizer_param_groups,
-                lr=self.cfg.optim.lr,
-                weight_decay=self.cfg.optim.get('weight_decay', 0.0),
+                **optim_kwargs,
             )
 
     def setup_mcore_distributed_parallel(self):
diff --git a/nemo/collections/nlp/modules/common/megatron/mup/optim.py b/nemo/collections/nlp/modules/common/megatron/mup/optim.py
index daee5c3438a2..f6ca24afd356 100644
--- a/nemo/collections/nlp/modules/common/megatron/mup/optim.py
+++ b/nemo/collections/nlp/modules/common/megatron/mup/optim.py
@@ -63,6 +63,16 @@ def MuOptimizer(params, **kwargs):
 from torch.optim import SGD, Adam, AdamW
 
 
+def is_adam_opt(optim_name):
+    return optim_name in [
+        'adam',
+        'adamw',
+        'fused_adam',
+        'distributed_fused_adam',
+        'megatron_fused_adam',
+    ]
+
+
 def process_param_groups(params, **kwargs):
     param_groups = list(params)
     if not isinstance(param_groups[0], dict):
@@ -72,17 +82,13 @@ def process_param_groups(params, **kwargs):
             param_group['lr'] = kwargs['lr']
         if 'weight_decay' not in param_group:
             param_group['weight_decay'] = kwargs.get('weight_decay', 0.0)
+        if 'eps' not in param_group and 'eps' in kwargs:
+            param_group['eps'] = kwargs['eps']
     return param_groups
 
 
 def process_mup_param_groups(optim_name, params, decoupled_wd=None, **kwargs):
-    if optim_name in [
-        'adam',
-        'adamw',
-        'fused_adam',
-        'distributed_fused_adam',
-        'megatron_fused_adam',
-    ]:
+    if is_adam_opt(optim_name):
         if decoupled_wd is None:
             decoupled_wd = optim_name != 'adam'
         param_groups = process_adam_param_groups(params, decoupled_wd=decoupled_wd, **kwargs)
@@ -124,6 +130,8 @@ def new_group():
             group['lr'] /= width_mult
             if not decoupled_wd:
                 group['weight_decay'] *= width_mult
+            if 'eps' in group:
+                group['eps'] /= width_mult
         new_param_groups.extend(list(matrix_like_p.values()) + [vector_like_p])
     return new_param_groups
 
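
For context beyond the diff itself: the change threads Adam's eps through the muP parameter-group processing, so that for Adam-family optimizers the matrix-like ("hidden" weight) groups have eps divided by width_mult together with the learning rate. The standalone Python sketch below illustrates that scaling rule only; it is not the NeMo code. is_adam_opt mirrors the helper added in mup/optim.py, while scale_adam_group_for_mup, the width_mult value, and the example parameter group are illustrative assumptions.

# Minimal standalone sketch of the scaling behaviour this diff adds (not the NeMo code itself).
# Assumptions: a single matrix-like parameter group and width_mult=4.0 are made up for illustration.
import torch
from torch.optim import AdamW


def is_adam_opt(optim_name):
    # Mirrors the helper introduced in mup/optim.py.
    return optim_name in [
        'adam',
        'adamw',
        'fused_adam',
        'distributed_fused_adam',
        'megatron_fused_adam',
    ]


def scale_adam_group_for_mup(group, width_mult, decoupled_wd=True):
    # muP scaling for Adam-family optimizers: matrix-like groups get lr (and,
    # with this diff, eps) divided by the width multiplier; weight decay is
    # only rescaled when it is coupled to lr (plain 'adam').
    group = dict(group)
    group['lr'] /= width_mult
    if not decoupled_wd:
        group['weight_decay'] *= width_mult
    if 'eps' in group:
        group['eps'] /= width_mult
    return group


if __name__ == '__main__':
    hidden = torch.nn.Linear(1024, 1024)
    base_group = {'params': hidden.parameters(), 'lr': 1e-3, 'weight_decay': 0.1, 'eps': 1e-8}
    scaled = scale_adam_group_for_mup(base_group, width_mult=4.0)
    assert is_adam_opt('fused_adam')
    opt = AdamW([scaled])
    print(opt.param_groups[0]['lr'], opt.param_groups[0]['eps'])  # 0.00025 2.5e-09

The apparent intent of dividing eps along with lr is to keep epsilon's magnitude in the same regime relative to the rescaled Adam denominator as width grows, which is why the megatron_gpt_model.py side only passes eps into process_mup_param_groups when is_adam_opt(optim_name) is true.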