
Commit 78a33e7
refac
vchiley committed Jan 18, 2024
1 parent 4232be0 commit 78a33e7
Showing 1 changed file with 11 additions and 8 deletions.
llmfoundry/models/mpt/modeling_mpt.py (19 changes: 11 additions & 8 deletions)
@@ -6,9 +6,10 @@
 Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
 """
 
+from __future__ import annotations
+
 import math
 import warnings
-from types import MethodType
 from typing import (Any, Dict, List, Mapping, MutableMapping, Optional, Tuple,
                     Union)
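
Why the new import matters: under PEP 563, "from __future__ import annotations" makes annotations lazily evaluated strings, so the module-level _fsdp_wrap_fn below can annotate self with Union[MPTModel, MPTForCausalLM] even though both classes are defined later in the file. A minimal, self-contained sketch of the mechanism (Model, ModelForCausalLM, and _wrap_fn are hypothetical names for illustration, not from modeling_mpt.py):

    from __future__ import annotations  # PEP 563: postpone annotation evaluation

    from typing import Union


    def _wrap_fn(self: Union[Model, ModelForCausalLM]) -> bool:
        # Model and ModelForCausalLM are defined *after* this function.
        # Without the __future__ import, evaluating this annotation at
        # definition time would raise NameError.
        return isinstance(self, (Model, ModelForCausalLM))


    class Model:
        pass


    class ModelForCausalLM:
        pass


    assert _wrap_fn(Model()) is True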

@@ -294,7 +295,7 @@ class MPTPreTrainedModel(PreTrainedModel):
 
 
 def _fsdp_wrap_fn(
-    self,  # type: ignore[no-untyped-def]
+    self: Union[MPTModel, MPTForCausalLM],
     module: nn.Module,
 ) -> bool:
     # FSDP Wrap function for MPT Models
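
How this hook gets used: trainers in the Composer mold discover fsdp_wrap_fn by plain attribute lookup on the model, so a MethodType attribute attached in __init__ and an ordinary method look identical from the caller's side. A hedged sketch of that consumer-side convention (should_wrap and its logic are assumptions for illustration, not Composer's actual source):

    import torch.nn as nn


    def should_wrap(root: nn.Module, child: nn.Module) -> bool:
        # Assumed convention: probe the root module for an fsdp_wrap_fn
        # attribute and ask it whether each submodule should become its
        # own FSDP unit. Both the old MethodType-attached attribute and
        # the new plain method satisfy this getattr, which is why the
        # refactor is behavior-preserving.
        wrap_fn = getattr(root, 'fsdp_wrap_fn', None)
        return wrap_fn(child) if callable(wrap_fn) else False
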
@@ -396,9 +397,6 @@ def __init__(self, config: MPTConfig):
         log.debug(self)
         log.debug(f'Using {self.config.init_config["name"]} initialization.')
 
-        # attach fsdp wrapping function
-        self.fsdp_wrap_fn = MethodType(_fsdp_wrap_fn, self)
-
     def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
         return self.wte

@@ -738,6 +736,10 @@ def param_init_fn(self, module: nn.Module) -> None:
             **self.config.init_config,
         )
 
+    # FSDP Wrap function
+    def fsdp_wrap_fn(self, module: nn.Module) -> bool:
+        return _fsdp_wrap_fn(self, module)
+
     # Activation Checkpointing
     def activation_checkpointing_fn(self, module: nn.Module) -> bool:
         return isinstance(module, MPTBlock)
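
This hunk (and its twin on MPTForCausalLM below) is the refactor in miniature: instead of binding the shared helper onto each instance with types.MethodType inside __init__, the class now defines a thin method that delegates to the helper. Behavior is unchanged, but the method is statically visible to type checkers and IDEs, and no per-instance attribute is created. A toy before/after with hypothetical names (Before, After, and _shared_helper are illustrative only):

    from types import MethodType


    def _shared_helper(self, x: int) -> int:
        return x + self.offset


    class Before:
        def __init__(self) -> None:
            self.offset = 1
            # old pattern: bind the module-level helper onto the instance
            self.apply = MethodType(_shared_helper, self)


    class After:
        def __init__(self) -> None:
            self.offset = 1

        # new pattern: a plain method that delegates to the shared helper
        def apply(self, x: int) -> int:
            return _shared_helper(self, x)


    assert Before().apply(2) == After().apply(2) == 3
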
@@ -781,9 +783,6 @@ def __init__(self, config: MPTConfig):
                     )
             self.logit_scale = logit_scale
 
-        # attach fsdp wrapping function
-        self.fsdp_wrap_fn = MethodType(_fsdp_wrap_fn, self)
-
     def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
         return self.transformer.get_input_embeddings()

@@ -898,6 +897,10 @@ def param_init_fn(self, module: nn.Module) -> None:
             **self.config.init_config,
         )
 
+    # FSDP Wrap function
+    def fsdp_wrap_fn(self, module: nn.Module) -> bool:
+        return _fsdp_wrap_fn(self, module)
+
     # Activation Checkpointing
     def activation_checkpointing_fn(self, module: nn.Module) -> bool:
         act_ckpt_list = getattr(self.config, 'activation_checkpointing_target',
