Remove no longer required monkey patching
Due to a more up-to-date Megatron-LM requirement, we can thankfully
remove this.
janEbert committed Jun 19, 2024
1 parent cb50e22 commit 1f019e5
Showing 1 changed file with 0 additions and 80 deletions.
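
The helper removed below guarded its monkey patch with an MD5 hash of the upstream method's source, so the patch could only apply to the exact Megatron-LM version it was written against; raising the Megatron-LM requirement is what makes the patch, and the guard, removable. A minimal self-contained sketch of that guard pattern follows; the names (Upstream, patch_instance, _patched_greet) are illustrative, not NeMo or Megatron-LM API.

import hashlib
import inspect
import types


class Upstream:
    """Stand-in for a third-party class whose method gets patched."""

    def greet(self):
        return "hello"


def _patched_greet(self):
    # Replacement behavior that gets bound onto a single instance.
    return "hello, patched"


def patch_instance(obj, expected_hash):
    # Refuse to patch when the upstream source no longer matches the version
    # the patch was written against (the removed NeMo code hard-coded this hash).
    current_hash = hashlib.md5(inspect.getsource(Upstream.greet).encode()).hexdigest()
    assert current_hash == expected_hash, (
        'cannot patch this version; please update the patching implementation.'
    )
    obj._old_greet = obj.greet
    obj.greet = types.MethodType(_patched_greet, obj)


# Record the hash once against the pinned upstream version, then patch safely.
pinned_hash = hashlib.md5(inspect.getsource(Upstream.greet).encode()).hexdigest()
instance = Upstream()
patch_instance(instance, pinned_hash)
assert instance.greet() == "hello, patched"
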
nemo/collections/nlp/modules/common/megatron/mup/layer.py (80 changes: 0 additions, 80 deletions)
@@ -156,91 +156,11 @@ def __getattr__(self, name: str):
return getattr(original_linear, name)


# This is copied almost verbatim from
# `megatron/core/models/gpt/gpt_model.py`, but we modify the
# `output_layer_key` string.
def _patched_sharded_state_dict(self, prefix: str = '', sharded_offsets: tuple = ()) -> ShardedStateDict:
assert not sharded_offsets, "Unexpected sharded offsets"
sharded_state_dict = {}

if self.pre_process:
embedding_prefix = f'{prefix}embedding.'
embedding_sharded_state_dict = self.embedding.sharded_state_dict(
prefix=embedding_prefix
)
sharded_state_dict.update(embedding_sharded_state_dict)

decoder_prefix = f'{prefix}decoder.'
decoder_sharded_state_dict = self.decoder.sharded_state_dict(prefix=decoder_prefix)
sharded_state_dict.update(decoder_sharded_state_dict)

if self.post_process:
output_layer_prefix = f'{prefix}output_layer.'
output_layer_key = f'{output_layer_prefix}_original_linear.weight'
if self.share_embeddings_and_output_weights:
if not self.pre_process:
# when sharing embeddings with last stage, we need to use the weights from the first stage
# on pipeline first rank, word embeddings are saved to {prefix}embedding.word_embeddings.weight
tensor = self.shared_embedding_or_output_weight()
first_stage_word_emb_key = f'{prefix}embedding.word_embeddings.weight'
last_stage_word_emb_replica_id = (
1, # copy of first stage embedding
0,
parallel_state.get_data_parallel_rank(with_context_parallel=True),
)

sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
tensor=tensor,
key=first_stage_word_emb_key,
replica_id=last_stage_word_emb_replica_id,
allow_shape_mismatch=True,
)

sharded_state_dict[output_layer_key] = sharded_output_layer_tensor

else:
output_layer_state_dict = self.output_layer.state_dict(
prefix=output_layer_prefix, keep_vars=True
)
output_layer_tensor = output_layer_state_dict[output_layer_key]
# independent output layer
sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
tensor=output_layer_tensor, key=output_layer_key, allow_shape_mismatch=True,
)

sharded_state_dict[output_layer_key] = sharded_output_layer_tensor

return sharded_state_dict


def patch_mcore_gptmodel_for_mup(model):
# If we don't have Megatron-LM, we have nothing to patch.
if MCoreGPTModel is None:
return

# Do some evil monkey patching in order to fix up state
# dicts. To make this safe, we explicitly check for whether
# our code matches and error out if it doesn't.
import hashlib
import inspect
import types

sharded_state_dict_code = MCoreGPTModel.sharded_state_dict
sharded_state_dict_func_hash = hashlib.md5(
inspect.getsource(sharded_state_dict_code).encode(),
).hexdigest()
assert (
sharded_state_dict_func_hash
== '192b67d1526c552d03ea830d2374657f'
), (
'cannot patch this version of Megatron-LM for μP. Please '
'update the state dict patching implementation to support it.'
)


model._old_sharded_state_dict = model.sharded_state_dict
model.sharded_state_dict = types.MethodType(_patched_sharded_state_dict, model)

model.output_layer = MCoreMuReadout(model.output_layer)


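The surviving context above also shows the other half of the μP setup, which this commit keeps: MCoreMuReadout wraps the existing output layer and forwards unknown attribute lookups to it via __getattr__. A minimal standalone sketch of that delegation-wrapper pattern; Linear and Readout here are illustrative stand-ins, not the actual NeMo classes.

class Linear:
    """Stand-in for the original output layer being wrapped."""

    def __init__(self, out_features):
        self.out_features = out_features

    def forward(self, x):
        return x


class Readout:
    """Wraps an existing layer and delegates unknown attributes to it."""

    def __init__(self, original_linear):
        self._original_linear = original_linear

    def forward(self, x):
        # A muP-style output adjustment would go here before delegating.
        return self._original_linear.forward(x)

    def __getattr__(self, name):
        # Only called when normal attribute lookup on the wrapper fails.
        return getattr(self._original_linear, name)


layer = Readout(Linear(out_features=16))
assert layer.out_features == 16  # resolved on the wrapped Linear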