Refactor RETRO muP implementation
Now the same as for the GPT model, re-using the (updated and fixed)
functionality.

Signed-off-by: janEbert <[email protected]>
janEbert committed Jun 19, 2024
1 parent a6dd128 commit e383fb9
Showing 3 changed files with 73 additions and 49 deletions.
61 changes: 60 additions & 1 deletion examples/nlp/language_modeling/megatron_retro_cal_shape.py
@@ -18,7 +18,7 @@
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin

from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel
from nemo.collections.nlp.modules.common.megatron.mup.shape import make_base_shapes
from nemo.collections.nlp.modules.common.megatron.mup.shape import append_base_head_widths, make_base_shapes
from nemo.collections.nlp.parts.nlp_overrides import (
CustomProgressBar,
GradScaler,
@@ -34,6 +34,25 @@ def main(cfg) -> None:
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

assert cfg.model.get('make_mup', False), \
'please configure `model.make_mup` to be `True` to calculate base shapes and make use of μP.'
assert cfg.model.get('shape_file', None), (
'please configure `model.shape_file` to point to a path in order to '
'save and later load the file containing base shapes.'
)
scalable_widths = cfg.model.get('mup_scalable_widths', [])
assert scalable_widths, (
'no `model.mup_scalable_widths` specified; need to specify config values to vary to create base shapes.'
)
shrink_factor = cfg.model.get('mup_delta_shrink_factor', 1)
assert (
# Do all scalable widths have a specified value?
all(isinstance(elem, (list, tuple)) and len(elem) == 2 for elem in scalable_widths)
or shrink_factor != 1
), (
'`model.mup_delta_shrink_factor` must be ≠1 if any scalable width does not have a specified value.'
)

megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False)
plugins = []
strategy = NLPDDPStrategy(
@@ -67,6 +86,40 @@ def main(cfg) -> None:
callbacks.append(CustomProgressBar())
trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer, callbacks=callbacks)

with open_dict(cfg):
cfg.base_model = cfg.model.copy()
del cfg.base_model.shape_file
cfg.delta_model = cfg.model.copy()
del cfg.delta_model.shape_file
# Just to make sure the configs were actually deep-copied.
assert cfg.model.get('shape_file', None), \
'configs were not deep-copied; the OmegaConf copying code needs an update.'

# Vary delta model config
for elem in scalable_widths:
need_delta_value = True
# Get config key to set in `delta_model` config and optionally a specified value.
if isinstance(elem, str):
cfg_key = elem
else:
assert isinstance(elem, (list, tuple)) and 1 <= len(elem) <= 2
cfg_key = elem[0]
if len(elem) > 1:
delta_value = elem[1]
need_delta_value = False

base_value = OmegaConf.select(cfg.delta_model, cfg_key)

# If we don't have a specified delta value, calculate it automatically.
if need_delta_value:
delta_value = base_value // shrink_factor
assert delta_value > 0, 'value became ≤0 after shrinking'
assert isinstance(base_value, int) and isinstance(delta_value, int), \
'scalable width value needs to be an integer'
assert delta_value != base_value, 'scalable width delta value needs to be different from base value'

OmegaConf.update(cfg.delta_model, cfg_key, delta_value)

# hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
with open_dict(cfg):
cfg.base_model.precision = cfg.trainer.precision
@@ -76,6 +129,12 @@ def main(cfg) -> None:
delta_model = MegatronRetrievalModel(cfg.delta_model, trainer)
make_base_shapes(base_model, delta_model, savefile=cfg.model.shape_file)

append_base_head_widths(
cfg.model.shape_file,
base_model,
['.self_attention', '.inter_attention', '.cross_attention', '.core_attention'],
)


if __name__ == '__main__':
main()
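
The script only proceeds if the model config carries the μP-related keys it asserts on. For reference, a minimal sketch of such a config built with OmegaConf; the key names mirror the assertions above, while the concrete values and the example width keys (hidden_size, ffn_hidden_size) are illustrative assumptions, not taken from this commit:

from omegaconf import OmegaConf

# Hypothetical μP section of cfg.model; key names follow the assertions in
# megatron_retro_cal_shape.py, values are placeholders only.
mup_model_cfg = OmegaConf.create(
    {
        'make_mup': True,                        # must be truthy to calculate base shapes
        'shape_file': 'retro_base_shapes.yaml',  # where make_base_shapes() saves the shapes
        # Each entry is either a plain config key (its delta value is derived by
        # integer-dividing the base value by mup_delta_shrink_factor) or a
        # [key, delta_value] pair giving the delta width explicitly.
        'mup_scalable_widths': ['hidden_size', ['ffn_hidden_size', 256]],
        'mup_delta_shrink_factor': 2,            # used only when no explicit delta value is given
    }
)

print(OmegaConf.to_yaml(mup_model_cfg))

With a config like this, the loop over mup_scalable_widths above would halve hidden_size for the delta model and set ffn_hidden_size directly to 256 before the base and delta models are instantiated and handed to make_base_shapes.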
5 changes: 0 additions & 5 deletions examples/nlp/language_modeling/megatron_retro_mutransfer_pretrain.py
@@ -19,24 +19,19 @@
from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector

from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel
from nemo.collections.nlp.modules.common.megatron.mup.optim import MuAdam, MuAdamW
from nemo.collections.nlp.parts.nlp_overrides import (
CustomProgressBar,
GradScaler,
MegatronHalfPrecisionPlugin,
NLPDDPStrategy,
)
from nemo.core.config import hydra_runner
from nemo.core.config.optimizers import AdamParams, AdamWParams
from nemo.core.optim.optimizers import register_optimizer
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf", config_name="megatron_retro_mutransfer")
def main(cfg) -> None:
register_optimizer("muadamw", MuAdamW, AdamWParams())
register_optimizer("muadam", MuAdam, AdamParams())
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

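
With the muadam/muadamw registrations removed, the pretraining script no longer needs a μP-specific optimizer: model.optim names a standard NeMo optimizer and the μP scaling is applied later, inside the model's setup_optimizer_param_groups (see the next file). A plausible optim block as an OmegaConf sketch; only the fused_adam default, lr, and weight_decay appear in this commit, while the numbers and the betas key are illustrative assumptions:

from omegaconf import OmegaConf

# Illustrative cfg.model.optim section; 'fused_adam' matches the default
# assumed by setup_optimizer_param_groups, the numbers are placeholders.
optim_cfg = OmegaConf.create(
    {
        'name': 'fused_adam',
        'lr': 6e-4,
        'weight_decay': 0.1,
        'betas': [0.9, 0.95],
    }
)

print(OmegaConf.to_yaml(optim_cfg))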
56 changes: 13 additions & 43 deletions nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py
@@ -29,8 +29,8 @@
)
from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel
from nemo.collections.nlp.modules.common.megatron.module import Float16Module
from nemo.collections.nlp.modules.common.megatron.mup.init import normal_
from nemo.collections.nlp.modules.common.megatron.mup.shape import set_base_shapes
from nemo.collections.nlp.modules.common.megatron.mup.convert import maybe_mup_init
from nemo.collections.nlp.modules.common.megatron.mup.optim import process_mup_param_groups
from nemo.collections.nlp.modules.common.megatron.retrieval_token_level_encoder_decoder import (
MegatronRetrievalTokenLevelEncoderDecoderModule,
)
@@ -102,47 +102,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
True if (not self.megatron_amp_O2) and (self.autocast_dtype in [torch.float16, torch.bfloat16]) else False
)

if hasattr(self.cfg, "shape_file"):
set_base_shapes(self, self.register_artifact("shape_file", self.cfg.shape_file), rescale_params=False)

# here manually initialize all the named parameters with the muTranfer normal initializer
for name, tensor in self.named_parameters():
if name.endswith('.dense_4h_to_h.weight') or name.endswith('.dense.weight'):
# initialize all the output dense matrix weight
# match the megatron lm model
std = self.cfg.init_method_std / math.sqrt(2.0 * 12.0)
normal_(tensor, 0, std)
elif name.endswith('layernorm.weight'):
# initialize all the layer norm weight
if tensor.std() != 0 and tensor.mean() != 1:
raise ValueError(f'need to check {name} init')
normal_(tensor, 1, 0)
elif name.endswith('.weight'):
# initialize all the other dense matrix weight
normal_(tensor, 0, self.cfg.init_method_std)
else:
if tensor.std() != 0 and tensor.mean() != 0:
raise ValueError(f'need to check {name} init')

# here manually overwrite the norm factor
# note, has to turn off the model.apply_query_key_layer_scaling
assert not self.cfg.apply_query_key_layer_scaling
for name, layer in self.named_modules():
if (
name.endswith('.self_attention')
or name.endswith('.inter_attention')
or name.endswith('.cross_attention')
or name.endswith('.core_attention')
):
if hasattr(layer, 'norm_factor') and hasattr(layer, 'hidden_size_per_attention_head'):
layer.norm_factor = (
layer.hidden_size_per_attention_head / 8.0
) # divide 8 to make it consist with ADLR setting
else:
if hasattr(layer, 'norm_factor') or hasattr(layer, 'hidden_size_per_attention_head'):
logging.error(
f'module {name} has norm factor but its name is not ending with attention, need to double check'
)
maybe_mup_init(self)

def _build_tokenizer(self):
self.tokenizer = get_nmt_tokenizer(
@@ -573,5 +533,15 @@ def setup_optimizer_param_groups(self):
"""ModelPT override. Optimizer will get self._optimizer_param_groups"""
self._optimizer_param_groups = get_params_for_weight_decay_optimization([self.model])

if self.cfg.get('make_mup', False) and hasattr(self.cfg, 'shape_file'):
# muP parameter group processing
optim_name = self.cfg.optim.get('name', 'fused_adam')
self._optimizer_param_groups = process_mup_param_groups(
optim_name,
self._optimizer_param_groups,
lr=self.cfg.optim.lr,
weight_decay=self.cfg.optim.get('weight_decay', 0.0),
)

def list_available_models(self):
pass

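
The parameter-group processing added above replaces the removed MuAdam/MuAdamW registration. Roughly, μP-style Adam variants divide the learning rate of matrix-like (two-or-more-dimensional, width-scaled) weights by their width multiplier while vector-like parameters such as biases and layer norms keep the base rate. Below is a minimal, self-contained toy illustration of that rule under a single global width multiplier; it is an assumption-laden sketch, not NeMo's actual process_mup_param_groups, which works per parameter from the recorded base shapes:

import torch
from torch import nn


def toy_mup_adam_param_groups(model: nn.Module, lr: float, weight_decay: float = 0.0,
                              width_mult: float = 1.0):
    """Toy μP Adam rule: matrix-like (>= 2-D) weights get lr / width_mult,
    vector-like parameters keep the base lr. A single global width_mult
    (current width / base width) is assumed here for simplicity."""
    matrix_like = [p for p in model.parameters() if p.dim() >= 2]
    vector_like = [p for p in model.parameters() if p.dim() < 2]
    return [
        {'params': matrix_like, 'lr': lr / width_mult, 'weight_decay': weight_decay},
        {'params': vector_like, 'lr': lr, 'weight_decay': weight_decay},
    ]


# Usage sketch: the groups feed a plain Adam optimizer.
model = nn.Sequential(nn.Linear(256, 1024), nn.ReLU(), nn.Linear(1024, 256))
optimizer = torch.optim.Adam(toy_mup_adam_param_groups(model, lr=3e-4, width_mult=4.0))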