
Commit

clean up
dakinggg committed Apr 11, 2024
1 parent 80d67fa commit f7e4fec
Showing 10 changed files with 44 additions and 16 deletions.
11 changes: 11 additions & 0 deletions llmfoundry/layers_registry.py
@@ -33,8 +33,19 @@
     entry_points=True,
     description=_ffns_with_norm_description)
 
+_ffns_with_megablocks_description = (
+    'The ffns_with_megablocks registry is used to register functions that build ffn layers using MegaBlocks.'
+    + 'See ffn.py for examples.')
+ffns_with_megablocks = create_registry(
+    'llmfoundry',
+    'ffns_with_megablocks',
+    generic_type=Callable,
+    entry_points=True,
+    description=_ffns_with_megablocks_description)
+
 __all__ = [
     'norms',
     'ffns',
     'ffns_with_norm',
+    'ffns_with_megablocks',
 ]
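
The new registry behaves like the existing ffns and ffns_with_norm registries, so external code can add its own MegaBlocks-style builder and the membership checks elsewhere in this commit pick it up automatically. A minimal usage sketch, not part of this commit; build_my_moe and 'my_moe' are hypothetical names, and it assumes llm-foundry at this commit is importable:

# Hypothetical sketch: registering a custom builder against the new registry.
from llmfoundry.layers_registry import ffns_with_megablocks

def build_my_moe(**kwargs):  # placeholder builder for illustration only
    raise NotImplementedError

ffns_with_megablocks.register('my_moe', func=build_my_moe)
assert 'my_moe' in ffns_with_megablocks  # same membership test build_ffn uses
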
7 changes: 4 additions & 3 deletions llmfoundry/models/layers/ffn.py
@@ -13,7 +13,8 @@
 from torch.distributed import ProcessGroup
 from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Shard
 
-from llmfoundry.layers_registry import ffns, ffns_with_norm
+from llmfoundry.layers_registry import (ffns, ffns_with_megablocks,
+                                        ffns_with_norm)
 from llmfoundry.models.layers.dmoe import dMoE
 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
 
@@ -444,5 +445,5 @@ def build_mb_dmoe(
     ffns_with_norm.register('te_ln_mlp', func=build_te_ln_mlp)
 
 if is_megablocks_imported:
-    ffns.register('mb_moe', func=build_mb_moe)
-    ffns.register('mb_dmoe', func=build_mb_dmoe)
+    ffns_with_megablocks.register('mb_moe', func=build_mb_moe)
+    ffns_with_megablocks.register('mb_dmoe', func=build_mb_dmoe)
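
Net effect of this file: the MegaBlocks builders move out of the generic ffns registry and into ffns_with_megablocks, still behind the MegaBlocks import guard. A hedged sketch of how that looks to callers; it assumes llm-foundry at this commit, and the second check only holds when megablocks is actually installed, since registration happens inside the guard:

from llmfoundry.layers_registry import ffns, ffns_with_megablocks

print('mb_moe' in ffns)                  # expected False after this commit
print('mb_moe' in ffns_with_megablocks)  # expected True, but only with megablocks installed
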
9 changes: 8 additions & 1 deletion llmfoundry/models/layers/layer_builders.py
@@ -5,7 +5,8 @@
 
 import torch
 
-from llmfoundry.layers_registry import ffns, ffns_with_norm, norms
+from llmfoundry.layers_registry import (ffns, ffns_with_megablocks,
+                                        ffns_with_norm, norms)
 from llmfoundry.utils.registry_utils import construct_from_registry
 
 
@@ -37,6 +38,9 @@ def build_ffn(
     if name in ffns_with_norm:
         registry_to_use = ffns_with_norm
 
+    if name in ffns_with_megablocks:
+        registry_to_use = ffns_with_megablocks
+
     kwargs = {
         'd_model': d_model,
         'expansion_ratio': expansion_ratio,
@@ -59,4 +63,7 @@ def _validation_function(maybe_module: Any):
     if name in ffns_with_norm:
         result._has_norm = True
 
+    if name in ffns_with_megablocks:
+        result._uses_megablocks = True
+
     return result
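
With this change, build_ffn dispatches to the most specific registry that contains the requested name and stamps the built module with _uses_megablocks (alongside the existing _has_norm), so downstream code can branch on the module itself rather than re-parsing ffn_config. A self-contained illustration of that convention; the nn.Identity stand-in and the print are mine, only the attribute name comes from this diff:

import torch.nn as nn

ffn = nn.Identity()          # stand-in for whatever build_ffn returned
ffn._uses_megablocks = True  # build_ffn sets this for names found in ffns_with_megablocks

if getattr(ffn, '_uses_megablocks', False):
    print('MegaBlocks FFN: expect MoE-specific handling downstream')
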
3 changes: 2 additions & 1 deletion llmfoundry/models/mpt/configuration_mpt.py
@@ -8,6 +8,7 @@
 
 from transformers import PretrainedConfig
 
+from llmfoundry.layers_registry import ffns_with_megablocks
 from llmfoundry.models.layers.attention import (check_alibi_support,
                                                 is_flash_v2_installed)
 from llmfoundry.models.layers.blocks import attn_config_defaults
@@ -290,7 +291,7 @@ def _validate_config(self) -> None:
             )
         elif self.ffn_config['ffn_type'] in ['mptmlp', 'mptglu']:
             self.ffn_config['fc_type'] = self.fc_type
-        elif self.ffn_config['ffn_type'] in ['mb_moe', 'mb_dmoe']:
+        elif self.ffn_config['ffn_type'] in ffns_with_megablocks:
             self.ffn_config['return_bias'] = False
         elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
             self.ffn_config['bias'] = not self.no_bias
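
_validate_config now keys the MoE branch off registry membership instead of a hard-coded name list, so any newly registered MegaBlocks-backed ffn_type gets the same treatment without editing this file. A hedged sketch of the branch's effect; it assumes llm-foundry at this commit with megablocks installed, since the mb_* names are only registered when that import succeeds:

from llmfoundry.layers_registry import ffns_with_megablocks

ffn_config = {'ffn_type': 'mb_dmoe'}       # illustrative config fragment
if ffn_config['ffn_type'] in ffns_with_megablocks:
    ffn_config['return_bias'] = False      # mirrors the _validate_config branch above
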
7 changes: 4 additions & 3 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -20,6 +20,7 @@
 from composer.models import HuggingFaceModel
 from composer.utils import dist
 
+from llmfoundry.layers_registry import ffns_with_megablocks
 from llmfoundry.models.layers.attention import is_flash_v2_installed
 
 if is_flash_v2_installed():
@@ -324,7 +325,7 @@ def __init__(self, config: MPTConfig):
         self.emb_drop = nn.Dropout(config.emb_pdrop)
         self.mb_args = None
         block_args = config.to_dict()
-        if block_args['ffn_config']['ffn_type'] in ('mb_moe', 'mb_dmoe'):
+        if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks:
             block_args['ffn_config'] = config_moe_args(
                 block_args['ffn_config'],
                 config.d_model,
@@ -1026,7 +1027,7 @@ def get_targets(self, batch: Mapping) -> torch.Tensor:
         return targets
 
     def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
-        if self.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'):
+        if self.config.ffn_config['ffn_type'] in ffns_with_megablocks:
             # Clear MegaBlocks MoE load balancing loss cache
             try:  # Add try/catch to avoid transformers complaining and raising errors
                 from megablocks.layers.moe import clear_load_balancing_loss
@@ -1053,7 +1054,7 @@ def loss(self, outputs: CausalLMOutputWithPast,
         else:
            loss = losses.sum() / (targets != self.loss_fn.ignore_index).sum()
 
-        if self.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'):
+        if self.config.ffn_config['ffn_type'] in ffns_with_megablocks:
             # MegaBlocks MoE load balancing loss
             try:  # Add try/catch to avoid transformers complaining and raising errors
                 from megablocks.layers.moe import batched_load_balancing_loss
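
Both call sites keep the MegaBlocks imports inside a try/except so the model still loads when megablocks is absent; only the membership test changes in this commit. A condensed sketch of that guarded-import idiom; the except body and the no-argument call are my assumptions about the megablocks API, not a copy of the file:

def _clear_moe_loss_cache() -> None:
    try:  # tolerate environments without megablocks installed
        from megablocks.layers.moe import clear_load_balancing_loss
    except ImportError:
        return
    clear_load_balancing_loss()  # assumed call shape; see forward() above for the real usage
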
3 changes: 2 additions & 1 deletion llmfoundry/models/utils/config_moe_args.py
@@ -9,6 +9,7 @@
 from packaging import version
 from torch import distributed
 
+from llmfoundry.layers_registry import ffns_with_megablocks
 from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size
 
 
@@ -156,7 +157,7 @@ def config_moe_args(
     Returns:
         ffn_config (dict): FFN configuration with MoE configured.
     """
-    if ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'):
+    if ffn_config['ffn_type'] in ffns_with_megablocks:
         return config_megablocks_moe_args(
             ffn_config=ffn_config,
             d_model=d_model,
10 changes: 6 additions & 4 deletions llmfoundry/models/utils/mpt_param_count.py
@@ -16,6 +16,8 @@
 from torch import Tensor, nn
 from torch.distributed._tensor import DTensor
 
+from llmfoundry.layers_registry import ffns_with_megablocks
+
 
 def module_n_params(module: nn.Module) -> int:
     """Gets the number of parameters in this module excluding child modules.
@@ -127,7 +129,7 @@ def megablocks_n_active_params(mpt_model) -> int:  # type: ignore
 
 
 def mpt_get_total_params(mpt_model) -> int:  # type: ignore
-    """Calculates the total paramter count of an MPT model.
+    """Calculates the total parameter count of an MPT model.
 
     Note: Must be called before model parameters are sharded by FSDP.
 
@@ -138,14 +140,14 @@ def mpt_get_total_params(mpt_model) -> int:  # type: ignore
     Returns:
         An int for the total number of parameters in this MPT model.
     """
-    if mpt_model.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'):
+    if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks:
         return megablocks_n_total_params(mpt_model)
     else:
         return sum(p.numel() for p in mpt_model.parameters())
 
 
 def mpt_get_active_params(mpt_model) -> int:  # type: ignore
-    """Calculates the total paramter count of an MPT model.
+    """Calculates the total parameter count of an MPT model.
 
     Note: Must be called before model parameters are sharded by FSDP.
 
@@ -156,7 +158,7 @@ def mpt_get_active_params(mpt_model) -> int:  # type: ignore
     Returns:
         An int for the active number of parameters in this MPT model.
     """
-    if mpt_model.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'):
+    if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks:
         params = megablocks_n_active_params(mpt_model)
     else:
         params = sum(p.numel() for p in mpt_model.parameters())
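
For dense (non-MoE) models, both helpers fall back to a plain parameter sum; MoE models are routed to the megablocks-aware counters. A self-contained example of that fallback path, pure PyTorch with no llm-foundry or MegaBlocks dependency; the toy model is mine:

import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))
total = sum(p.numel() for p in model.parameters())
print(total)  # 8*16 + 16 + 16*4 + 4 = 212
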
4 changes: 3 additions & 1 deletion llmfoundry/registry.py
@@ -12,7 +12,8 @@
 from transformers import PreTrainedTokenizerBase
 
 from llmfoundry.interfaces import CallbackWithConfig
-from llmfoundry.layers_registry import ffns, ffns_with_norm, norms
+from llmfoundry.layers_registry import (ffns, ffns_with_megablocks,
+                                        ffns_with_norm, norms)
 from llmfoundry.utils.registry_utils import create_registry
 
 _loggers_description = (
@@ -123,4 +124,5 @@
     'norms',
     'ffns',
     'ffns_with_norm',
+    'ffns_with_megablocks',
 ]
3 changes: 2 additions & 1 deletion llmfoundry/utils/config_utils.py
@@ -11,6 +11,7 @@
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
 
+from llmfoundry.layers_registry import ffns_with_megablocks
 from llmfoundry.models.utils import init_empty_weights
 
 log = logging.getLogger(__name__)
@@ -131,7 +132,7 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
 
     # Set ffn_config.device_mesh to fsdp_config.device_mesh
     if fsdp_config is not None and 'device_mesh' in fsdp_config and 'ffn_config' in model_cfg and model_cfg[
-            'ffn_config'].get('ffn_type', None) in {'mb_moe', 'mb_dmoe'}:
+            'ffn_config'].get('ffn_type', None) in ffns_with_megablocks:
         # Raise ValueError if not using device mesh with MoE expert parallelism
         if fsdp_config['device_mesh'] is None and model_cfg['ffn_config'].get(
                 'moe_world_size', 1) > 1:
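
The constraint this branch enforces is visible even in the truncated hunk: MegaBlocks expert parallelism (moe_world_size > 1) requires fsdp_config.device_mesh to be set. A hedged, self-contained restatement of just that check; the config dicts and the error message wording are illustrative, not copied from the file:

fsdp_config = {'device_mesh': None}
ffn_config = {'ffn_type': 'mb_dmoe', 'moe_world_size': 2}

if fsdp_config['device_mesh'] is None and ffn_config.get('moe_world_size', 1) > 1:
    raise ValueError('MoE expert parallelism (moe_world_size > 1) requires fsdp_config.device_mesh to be set.')
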
3 changes: 2 additions & 1 deletion scripts/train/train.py
@@ -26,6 +26,7 @@
 install()
 from llmfoundry.callbacks import AsyncEval
 from llmfoundry.data.dataloader import build_dataloader
+from llmfoundry.layers_registry import ffns_with_megablocks
 from llmfoundry.utils.builders import (add_metrics_to_eval_loaders,
                                        build_algorithm, build_callback,
                                        build_composer_model, build_evaluators,
@@ -102,7 +103,7 @@ def validate_config(cfg: DictConfig):
     )
 
     if cfg.model.get('ffn_config', {}).get('ffn_type',
-                                           'mptmlp') in ('mb_moe', 'mb_dmoe'):
+                                           'mptmlp') in ffns_with_megablocks:
         moe_world_size = cfg.model.get('ffn_config',
                                        {}).get('moe_world_size', 1)
         use_orig_params = cfg.get('fsdp_config',
