Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Apr 11, 2024
2 parents 6016dac + b5fc0fa commit 073eb7e
Show file tree
Hide file tree
Showing 32 changed files with 2,420 additions and 158 deletions.
4 changes: 4 additions & 0 deletions llmfoundry/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet
from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
from llmfoundry.callbacks.log_mbmoe_tok_per_expert_callback import \
MegaBlocksMoE_TokPerExpert
from llmfoundry.callbacks.monolithic_ckpt_callback import \
MonolithicCheckpointSaver
from llmfoundry.callbacks.resumption_callbacks import (GlobalLRScaling,
Expand All @@ -34,6 +36,7 @@
callbacks.register('scheduled_gc', func=ScheduledGarbageCollector)
callbacks.register('oom_observer', func=OOMObserver)
callbacks.register('eval_output_logging', func=EvalOutputLogging)
callbacks.register('mbmoe_tok_per_expert', func=MegaBlocksMoE_TokPerExpert)

callbacks_with_config.register('async_eval', func=AsyncEval)
callbacks_with_config.register('curriculum_learning', func=CurriculumLearning)
Expand All @@ -46,6 +49,7 @@
'ScheduledGarbageCollector',
'EvalGauntlet',
'HuggingFaceCheckpointer',
'MegaBlocksMoE_TokPerExpert',
'AsyncEval',
'CurriculumLearning',
]
70 changes: 58 additions & 12 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Any, Dict, List, Optional, Sequence, Union

import torch
import torch.nn as nn
from composer.core import Callback, Event, State, Time, TimeUnit
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger, MLFlowLogger
Expand All @@ -24,6 +25,7 @@
parse_uri)
from composer.utils.misc import create_interval_scheduler
from mlflow.transformers import _fetch_model_card, _write_license_information
from packaging import version
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
Expand Down Expand Up @@ -312,28 +314,72 @@ def _save_checkpoint(self, state: State, logger: Logger):
state_dict_model = state.model.model
original_tokenizer = state.model.tokenizer

state_dict_context = fsdp_state_dict_type_context(
original_model,
state_dict_type='full') if ((not state.is_model_ddp) and isinstance(
state_dict_model, FSDP)) else contextlib.nullcontext()

with state_dict_context:
state_dict = state_dict_model.state_dict()

# convert the state dict to the requested precision
for k, v in state_dict.items():
if isinstance(v, torch.Tensor):
state_dict[k] = v.to(dtype=self.dtype)
if version.parse(torch.__version__) > version.parse('2.2.9'):
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
StateDictOptions, get_model_state_dict)
cpu_offload = True

# Add a dtensor->cpu tensor hook to avoid CUDA OOM
# State-dict hook: replaces each DTensor entry of `state_dict` with the result
# of `full_tensor()` and leaves those entries populated only on global rank 0.
# `module`, `prefix` and `*args` are accepted but never read in the body.
# Returns the same `state_dict` object it was given, mutated in place.
def dtensor_to_tensor_hook(
module: nn.Module,
state_dict: Dict[str, Any],
prefix: str,
*args: Any,
) -> Dict[str, Any]:
# Keys whose values were DTensors; used below to drop them on non-zero ranks.
dtensor_fqns = []
# Only existing keys are reassigned inside this loop (no insert/delete), so
# iterating over `.keys()` while writing is safe.
for fqn in state_dict.keys():
tensor = state_dict[fqn]
if isinstance(tensor, DTensor):
dtensor_fqns.append(fqn)
# NOTE(review): presumably a collective gather that every rank must enter,
# which would explain why it precedes the rank check — confirm.
tensor = tensor.full_tensor() # type: ignore
if dist.get_global_rank() == 0:
# `cpu_offload` is read from the enclosing scope.
if cpu_offload:
tensor = tensor.cpu()
state_dict[fqn] = tensor
# Every rank other than global rank 0 deletes the DTensor entries recorded above.
if dist.get_global_rank() != 0:
for fqn in dtensor_fqns:
del state_dict[fqn]
return state_dict

hooks = []
for _, module in state_dict_model.named_modules():
if isinstance(module, FSDP):
hooks.append(
module._register_state_dict_hook(
dtensor_to_tensor_hook))

state_dict = get_model_state_dict(state_dict_model,
options=StateDictOptions(
full_state_dict=True,
cpu_offload=cpu_offload))
for hook in hooks:
hook.remove()
else:
state_dict_context = fsdp_state_dict_type_context(
original_model, state_dict_type='full') if (
(not state.is_model_ddp) and isinstance(
state_dict_model, FSDP)) else contextlib.nullcontext()
with state_dict_context:
state_dict = state_dict_model.state_dict()

# Convert the state dict to the requested precision
for k, v in state_dict.items():
if isinstance(v, torch.Tensor):
state_dict[k] = v.to(dtype=self.dtype)

new_model_instance = None # Need this for pyright because variable could be unbound

if dist.get_global_rank() == 0:
log.debug('Saving Hugging Face checkpoint in global rank 0')

# Edit HF config before building 2nd model copy
copied_config = copy.deepcopy(original_model.config)
if copied_config.model_type == 'mpt':
copied_config.attn_config['attn_impl'] = 'torch'
copied_config.init_device = 'cpu'
if 'moe_world_size' in getattr(copied_config, 'ffn_config', {}):
copied_config.ffn_config['moe_world_size'] = 1

log.debug(f'Creating new model instance')

Expand Down
140 changes: 140 additions & 0 deletions llmfoundry/callbacks/log_mbmoe_tok_per_expert_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Log tokens per expert for MegaBlocks MoE."""
from __future__ import annotations

import torch
from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import dist


class MegaBlocksMoE_TokPerExpert(Callback):
    """Log tokens per expert for MegaBlocks MoE.

    To compute the load balancing loss, MegaBlocks caches information including
    `tokens_per_expert` (tpe). At the :attr:`.Event.BATCH_END` event this callback
    gets load_balancing_loss from MegaBlocks to get `tokens_per_expert`, then logs
    statistics (<STAT>) of the number of tokens assigned to experts: aggregated
    over all layers under ``mb_moe/all_layers_<STAT>_tpe`` and, when
    ``log_every_layer`` is set, for each layer index (l_idx) under
    ``mb_moe/layer<l_idx>_<STAT>_tpe``.

    The tokens_per_expert statistics are logged by the :class:`.Logger` to the
    following keys as described below.

    +----------------------------------+-----------------------------------------------------------+
    | Key                              | Logged data                                               |
    +==================================+===========================================================+
    | `mb_moe/all_layers_min_tpe`      | Minimum tokens per expert across all layers               |
    +----------------------------------+-----------------------------------------------------------+
    | `mb_moe/all_layers_max_tpe`      | Maximum tokens per expert across all layers               |
    +----------------------------------+-----------------------------------------------------------+
    | `mb_moe/all_layers_median_tpe`   | Median tokens per expert across all layers                |
    +----------------------------------+-----------------------------------------------------------+
    | `mb_moe/all_layers_std_tpe`      | Standard deviation of tokens per expert across all layers |
    +----------------------------------+-----------------------------------------------------------+
    | `mb_moe/layer<l_idx>_min_tpe`    | Minimum tokens per expert at l_idx layer                  |
    +----------------------------------+-----------------------------------------------------------+
    | `mb_moe/layer<l_idx>_max_tpe`    | Maximum tokens per expert at l_idx layer                  |
    +----------------------------------+-----------------------------------------------------------+
    | `mb_moe/layer<l_idx>_median_tpe` | Median tokens per expert at l_idx layer                   |
    +----------------------------------+-----------------------------------------------------------+
    | `mb_moe/layer<l_idx>_std_tpe`    | Standard deviation of tokens per expert at l_idx layer    |
    +----------------------------------+-----------------------------------------------------------+

    Args:
        log_interval (int, optional): The interval on which to log (Default: 10).
        log_every_layer (bool, optional): Enable logging every layer's statistics (True) or log
            only aggregate statistics (Default: False).
        all_reduce_stats (bool, optional): Enable aggregating statistics across gpus (True) or log
            statistics for GPU 0 (Default: False).
        normalize (bool, optional): Normalize token counts by total tokens (Default: True) or output
            raw token count (False). When normalize is True, the callback displays the fraction of
            unique tokens routed to each expert. When normalize is False, the callback displays the
            total number of tokens routed to each expert.
    """

    def __init__(
        self,
        log_interval: int = 10,
        log_every_layer: bool = False,
        all_reduce_stats: bool = False,
        normalize: bool = True,
    ):
        super().__init__()
        self.log_interval = log_interval
        self.log_every_layer = log_every_layer
        self.all_reduce_stats = all_reduce_stats
        self.normalize = normalize

        # Filled in by ``fit_start`` from the first MoE module found in the
        # model; only needed when ``normalize`` is enabled.
        self.topk = None

    def fit_start(self, state: State, logger: Logger) -> None:
        """Look up ``moe_top_k`` from the model when normalization is enabled.

        Raises:
            RuntimeError: If MegaBlocks cannot be imported, or if no ``MoE`` /
                ``dMoE`` module is found in ``state.model``.
        """
        if self.topk is not None or not self.normalize:
            return

        try:
            from megablocks.layers.dmoe import dMoE
            from megablocks.layers.moe import MoE
        except ImportError as e:
            # Chain the cause so the underlying import failure stays visible.
            raise RuntimeError(
                'Requirements for MegaBlocks not installed; see install instructions in `README.md`.'
            ) from e

        for module in state.model.modules():
            if isinstance(module, (MoE, dMoE)):
                self.topk = module.experts.args.moe_top_k
                return

        raise RuntimeError(
            'Callback not initialized correctly; self.topk not instantiated.')

    def batch_end(self, state: State, logger: Logger) -> None:
        """Log tokens-per-expert statistics every ``log_interval`` batches.

        Raises:
            RuntimeError: If MegaBlocks cannot be imported.
        """
        if state.timestamp.batch.value % self.log_interval != 0:
            return

        try:
            from megablocks.layers.moe import get_load_balancing_loss
        except ImportError as e:
            raise RuntimeError(
                'Requirements for MegaBlocks not installed; see install instructions in `README.md`.'
            ) from e

        cached = list(get_load_balancing_loss())
        if not cached:
            # Nothing cached for this batch: skip logging instead of failing
            # on the tuple unpacking below.
            return

        tokens_per_expert, _ = zip(*cached)

        # Work on detached copies so the all-reduce / normalization below
        # never touches the tensors MegaBlocks holds in its cache.
        tokens_per_expert = [tpe.clone().detach() for tpe in tokens_per_expert]

        if self.all_reduce_stats:
            for tpe in tokens_per_expert:
                dist.all_reduce(tpe)

        if self.normalize:
            tokens_per_expert = [
                tpe / (tpe.sum() / self.topk) for tpe in tokens_per_expert
            ]

        log_info = self._tpe_stats(torch.concat(tokens_per_expert),
                                   'mb_moe/all_layers')

        if self.log_every_layer:
            for l_idx, tpe_layer in enumerate(tokens_per_expert):
                log_info.update(
                    self._tpe_stats(tpe_layer, f'mb_moe/layer{l_idx}'))

        logger.log_metrics(log_info)

    @staticmethod
    def _tpe_stats(tpe: torch.Tensor, prefix: str) -> dict:
        """Return min/max/median/std of ``tpe`` keyed as ``<prefix>_<STAT>_tpe``."""
        return {
            f'{prefix}_min_tpe': tpe.min().item(),
            f'{prefix}_max_tpe': tpe.max().item(),
            f'{prefix}_median_tpe': tpe.median().item(),
            # std is computed on a float cast, as raw counts may be integers
            f'{prefix}_std_tpe': tpe.float().std().item(),
        }
2 changes: 2 additions & 0 deletions llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ def build_finetuning_dataloader(cfg: DictConfig,
sampling_granularity=cfg.dataset.get('sampling_granularity', 1),
batching_method=cfg.dataset.get('batching_method', 'random'),
max_seq_len=cfg.dataset.max_seq_len,
allow_unsafe_types=cfg.dataset.get('allow_unsafe_types', False),
replication=cfg.dataset.get('replication', None),
)

else:
Expand Down
10 changes: 10 additions & 0 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,12 @@ class StreamingFinetuningDataset(StreamingDataset):
Defaults to ``1``.
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
``per_stream``. Defaults to ``random``.
allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
execution during deserialization, whether to keep going if ``True`` or raise an error
if ``False``. Defaults to ``False``.
replication (int, optional): Determines how many consecutive devices will receive the same
samples. Useful for training with tensor or sequence parallelism, where multiple
devices need to see the same partition of the dataset. Defaults to ``None``.
"""

def __init__(self,
Expand All @@ -516,6 +522,8 @@ def __init__(self,
sampling_granularity: int = 1,
batching_method: str = 'random',
max_seq_len: int = 2048,
allow_unsafe_types: bool = False,
replication: Optional[int] = None,
**kwargs: Any):

if len(kwargs) > 0:
Expand Down Expand Up @@ -552,6 +560,8 @@ def __init__(self,
sampling_method=sampling_method,
sampling_granularity=sampling_granularity,
batching_method=batching_method,
allow_unsafe_types=allow_unsafe_types,
replication=replication,
)

self.tokenizer = tokenizer
Expand Down
10 changes: 10 additions & 0 deletions llmfoundry/data/text_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ class StreamingTextDataset(StreamingDataset):
Defaults to ``1``.
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
``per_stream``. Defaults to ``random``.
allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
execution during deserialization, whether to keep going if ``True`` or raise an error
if ``False``. Defaults to ``False``.
replication (int, optional): Determines how many consecutive devices will receive the same
samples. Useful for training with tensor or sequence parallelism, where multiple
devices need to see the same partition of the dataset. Defaults to ``None``.
"""

def __init__(self,
Expand All @@ -109,6 +115,8 @@ def __init__(self,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
batching_method: str = 'random',
allow_unsafe_types: bool = False,
replication: Optional[int] = None,
**kwargs: Any):

if len(kwargs) > 0:
Expand Down Expand Up @@ -151,6 +159,8 @@ def __init__(self,
sampling_method=sampling_method,
sampling_granularity=sampling_granularity,
batching_method=batching_method,
allow_unsafe_types=allow_unsafe_types,
replication=replication,
)
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
Expand Down
Loading

0 comments on commit 073eb7e

Please sign in to comment.