Skip to content

Commit

Permalink
Fix coord check with updated muP API
Browse files Browse the repository at this point in the history
Now scales all specified widths.

Also handle shrink factors potentially being too high.
  • Loading branch information
janEbert committed Aug 15, 2024
1 parent a873636 commit ec55113
Showing 1 changed file with 26 additions and 8 deletions.
34 changes: 26 additions & 8 deletions examples/nlp/language_modeling/megatron_gpt_coord_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,39 @@ def main(cfg) -> None:
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

shrink_factors = [1, 2, 4, 8, 16, 32]
base_width = cfg.model.hidden_size
base_ffn_width = cfg.model.ffn_hidden_size
scalable_widths = cfg.model.get('mup_scalable_widths', [])
assert scalable_widths, (
'no `model.mup_scalable_widths` specified; need to specify config values to vary to do coordinate check.'
)

base_widths = {}
for elem in scalable_widths:
# Get config key to query from `model` config.
if isinstance(scalable_widths, dict) or isinstance(elem, str):
cfg_key = elem
else:
assert isinstance(elem, (list, tuple)) and 1 <= len(elem) <= 2
cfg_key = elem[0]

base_value = OmegaConf.select(cfg.model, cfg_key)
base_widths[cfg_key] = base_value

shrink_factors = [1]
for shrink_factor in [2, 4, 8, 16, 32]:
if any(base_value // shrink_factor <= 0 for base_value in base_widths.values()):
break
shrink_factors.append(shrink_factor)
df = []

for shrink_factor in shrink_factors:
width = base_width // shrink_factor
ffn_width = base_ffn_width // shrink_factor

# `set_base_shapes` returns the model
new_cfg = cfg.copy()
if hasattr(new_cfg.model.optim, 'sched'):
del new_cfg.model.optim.sched
new_cfg.model.hidden_size = width
new_cfg.model.ffn_hidden_size = ffn_width

for (cfg_key, base_value) in base_widths.items():
delta_value = base_value // shrink_factor
OmegaConf.update(new_cfg.model, cfg_key, delta_value)
trainer = MegatronTrainerBuilder(new_cfg).create_trainer()

model = CoordCheckMegatronGPTModel(new_cfg.model, trainer)
Expand Down

0 comments on commit ec55113

Please sign in to comment.