Commit 78ac2d9
Merge branch 'main' into openai_compatible_gauntlet
bmosaicml authored Apr 10, 2024
2 parents c303b91 + 17f8aeb commit 78ac2d9
Showing 56 changed files with 3,156 additions and 626 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -91,6 +91,7 @@ Tutorial videos from the community:
Something missing? Contribute with a PR!

# Latest News
* [Blog: Introducing DBRX: A New State-of-the-Art Open LLM](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm)
* [Blog: LLM Training and Inference with Intel Gaudi2 AI Accelerators](https://www.databricks.com/blog/llm-training-and-inference-intel-gaudi2-ai-accelerators)
* [Blog: Training LLMs at Scale with AMD MI250 GPUs](https://www.databricks.com/blog/training-llms-scale-amd-mi250-gpus)
* [Blog: Training LLMs with AMD MI250 GPUs and MosaicML](https://www.mosaicml.com/blog/amd-mi250)
@@ -305,7 +306,7 @@ dependencies = [
"llm-foundry",
]

[project.entry-points."llm_foundry.loggers"]
[project.entry-points."llmfoundry_loggers"]
my_logger = "foundry_registry.loggers:MyLogger"
```
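For reference, here is a minimal sketch of what the plugin module named by this entry point might contain. The module path `foundry_registry/loggers.py` and the class `MyLogger` are just the placeholder names from the README example, and the sketch assumes registered loggers follow Composer's `LoggerDestination` interface, as LLM Foundry's built-in loggers do:

```python
# Hypothetical contents of foundry_registry/loggers.py, the module referenced
# by the entry point above. Assumes plugin loggers subclass Composer's
# LoggerDestination; the class body is a toy stdout logger.
from typing import Any, Dict, Optional

from composer.loggers import LoggerDestination


class MyLogger(LoggerDestination):
    """Toy logger that prints metrics to stdout."""

    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
        for name, value in metrics.items():
            print(f'step={step} {name}={value}')
```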

4 changes: 4 additions & 0 deletions llmfoundry/callbacks/__init__.py
@@ -11,6 +11,8 @@
from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet
from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
from llmfoundry.callbacks.log_mbmoe_tok_per_expert_callback import \
MegaBlocksMoE_TokPerExpert
from llmfoundry.callbacks.monolithic_ckpt_callback import \
MonolithicCheckpointSaver
from llmfoundry.callbacks.resumption_callbacks import (GlobalLRScaling,
@@ -34,6 +36,7 @@
callbacks.register('scheduled_gc', func=ScheduledGarbageCollector)
callbacks.register('oom_observer', func=OOMObserver)
callbacks.register('eval_output_logging', func=EvalOutputLogging)
callbacks.register('mbmoe_tok_per_expert', func=MegaBlocksMoE_TokPerExpert)

callbacks_with_config.register('async_eval', func=AsyncEval)
callbacks_with_config.register('curriculum_learning', func=CurriculumLearning)
@@ -46,6 +49,7 @@
'ScheduledGarbageCollector',
'EvalGauntlet',
'HuggingFaceCheckpointer',
'MegaBlocksMoE_TokPerExpert',
'AsyncEval',
'CurriculumLearning',
]
404 changes: 260 additions & 144 deletions llmfoundry/callbacks/hf_checkpointer.py

Large diffs are not rendered by default.

140 changes: 140 additions & 0 deletions llmfoundry/callbacks/log_mbmoe_tok_per_expert_callback.py
@@ -0,0 +1,140 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Log tokens per expert for MegaBlocks MoE."""
from __future__ import annotations

import torch
from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import dist


class MegaBlocksMoE_TokPerExpert(Callback):
"""Log tokens per expert for MegaBlocks MoE.

To compute the load balancing loss, MegaBlocks caches information including `tokens_per_expert`
(tpe). At the :attr:`.Event.BATCH_END` event, this callback reads the load-balancing state cached
by MegaBlocks to obtain `tokens_per_expert`, then logs statistics (<STAT>) of the number of tokens
assigned to experts for each layer index (l_idx) under ``mb_moe/layer<l_idx>_<STAT>_tpe``.

The tokens_per_expert statistics are logged by the :class:`.Logger` to the following keys, as
described below.

+----------------------------------+-----------------------------------------------------------+
| Key                              | Logged data                                               |
+==================================+===========================================================+
| `mb_moe/all_layers_min_tpe`      | Minimum tokens per expert across all layers               |
+----------------------------------+-----------------------------------------------------------+
| `mb_moe/all_layers_max_tpe`      | Maximum tokens per expert across all layers               |
+----------------------------------+-----------------------------------------------------------+
| `mb_moe/all_layers_median_tpe`   | Median tokens per expert across all layers                |
+----------------------------------+-----------------------------------------------------------+
| `mb_moe/all_layers_std_tpe`      | Standard deviation of tokens per expert across all layers |
+----------------------------------+-----------------------------------------------------------+
| `mb_moe/layer<l_idx>_min_tpe`    | Minimum tokens per expert at layer l_idx                  |
+----------------------------------+-----------------------------------------------------------+
| `mb_moe/layer<l_idx>_max_tpe`    | Maximum tokens per expert at layer l_idx                  |
+----------------------------------+-----------------------------------------------------------+
| `mb_moe/layer<l_idx>_median_tpe` | Median tokens per expert at layer l_idx                   |
+----------------------------------+-----------------------------------------------------------+
| `mb_moe/layer<l_idx>_std_tpe`    | Standard deviation of tokens per expert at layer l_idx    |
+----------------------------------+-----------------------------------------------------------+

Args:
    log_interval (int, optional): The interval on which to log (Default: 10).
    log_every_layer (bool, optional): Enable logging every layer's statistics (True) or log
        only aggregate statistics (Default: False).
    all_reduce_stats (bool, optional): Enable aggregating statistics across GPUs (True) or log
        statistics for GPU 0 only (Default: False).
    normalize (bool, optional): Normalize token counts by total tokens (Default: True) or output
        raw token counts (False). When normalize is True, the callback reports the fraction of
        unique tokens routed to each expert; when False, it reports the total number of tokens
        routed to each expert.
"""

def __init__(
self,
log_interval: int = 10,
log_every_layer: bool = False,
all_reduce_stats: bool = False,
normalize: bool = True,
):
self.log_interval = log_interval
self.log_every_layer = log_every_layer
self.all_reduce_stats = all_reduce_stats
self.normalize = normalize

self.topk = None

def fit_start(self, state: State, logger: Logger) -> None:
if self.topk is None and self.normalize:
try:
from megablocks.layers.dmoe import dMoE
from megablocks.layers.moe import MoE
except:
raise RuntimeError(
'Requirements for MegaBlocks not installed; see install instructions in `README.md`.'
)
for module in state.model.modules():
if isinstance(module, (MoE, dMoE)):
self.topk = module.experts.args.moe_top_k
return

raise RuntimeError(
f'Callback not initialized correctly; self.topk not instantiated.'
)

def batch_end(self, state: State, logger: Logger) -> None:
if state.timestamp.batch.value % self.log_interval == 0:
try:
from megablocks.layers.moe import get_load_balancing_loss
except:
raise RuntimeError(
'Requirements for MegaBlocks not installed; see install instructions in `README.md`.'
)
tokens_per_expert, _ = zip(*get_load_balancing_loss())

tokens_per_expert = [
tpe.clone().detach() for tpe in tokens_per_expert
]
if self.all_reduce_stats:
for tpe in tokens_per_expert:
dist.all_reduce(tpe)

if self.normalize:
tokens_per_expert = [
tpe / (tpe.sum() / self.topk) for tpe in tokens_per_expert
]

all_tokens_per_expert = torch.concat(tokens_per_expert)

min_tpe = all_tokens_per_expert.min().item()
max_tpe = all_tokens_per_expert.max().item()
median_tpe = all_tokens_per_expert.median().item()
std_tpe = all_tokens_per_expert.float().std().item()

log_info = {
f'mb_moe/all_layers_min_tpe': min_tpe,
f'mb_moe/all_layers_max_tpe': max_tpe,
f'mb_moe/all_layers_median_tpe': median_tpe,
f'mb_moe/all_layers_std_tpe': std_tpe,
}

if self.log_every_layer:
for l_idx, tpe_layer in enumerate(tokens_per_expert):

min_tpe = tpe_layer.min().item()
max_tpe = tpe_layer.max().item()
median_tpe = tpe_layer.median().item()
std_tpe = tpe_layer.float().std().item()

log_info.update({
f'mb_moe/layer{l_idx}_min_tpe': min_tpe,
f'mb_moe/layer{l_idx}_max_tpe': max_tpe,
f'mb_moe/layer{l_idx}_median_tpe': median_tpe,
f'mb_moe/layer{l_idx}_std_tpe': std_tpe,
})

logger.log_metrics(log_info)
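As a standalone sketch (not part of this diff) of the normalization this callback applies when `normalize=True`: with made-up per-expert token counts and an assumed `moe_top_k` of 2, dividing by `tpe.sum() / topk` turns raw counts into the fraction of unique tokens routed to each expert, since every token is routed to `topk` experts.

```python
# Illustration of the normalize=True statistic with made-up token counts for a
# single MoE layer; only the callback constructor and the arithmetic below
# come from the code in this file.
import torch

from llmfoundry.callbacks import MegaBlocksMoE_TokPerExpert

cb = MegaBlocksMoE_TokPerExpert(log_interval=10, log_every_layer=True)

tokens_per_expert = torch.tensor([600.0, 200.0, 150.0, 50.0])  # 4 experts, fake counts
topk = 2  # each token routed to 2 experts, so counts sum to 2x the unique tokens

normalized = tokens_per_expert / (tokens_per_expert.sum() / topk)
print(normalized)  # tensor([1.2000, 0.4000, 0.3000, 0.1000]) -- fraction of unique tokens per expert
print(normalized.min().item(), normalized.max().item(),
      normalized.median().item(), normalized.float().std().item())
```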
2 changes: 2 additions & 0 deletions llmfoundry/data/finetuning/dataloader.py
@@ -170,6 +170,8 @@ def build_finetuning_dataloader(cfg: DictConfig,
sampling_granularity=cfg.dataset.get('sampling_granularity', 1),
batching_method=cfg.dataset.get('batching_method', 'random'),
max_seq_len=cfg.dataset.max_seq_len,
allow_unsafe_types=cfg.dataset.get('allow_unsafe_types', False),
replication=cfg.dataset.get('replication', None),
)

else:
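As a usage sketch, the two new options surface as keys under `dataset` in a streaming finetuning dataloader config. The remote/local paths below are placeholders, and the surrounding keys follow the usual finetuning setup rather than anything introduced here; the resulting config is what gets passed to `build_finetuning_dataloader` together with a tokenizer and the per-device batch size.

```python
# Sketch of a streaming finetuning dataloader config exercising the new
# dataset options; paths are placeholders.
from omegaconf import OmegaConf

dataloader_cfg = OmegaConf.create({
    'name': 'finetuning',
    'dataset': {
        'remote': 's3://my-bucket/my-finetuning-data',  # placeholder
        'local': '/tmp/my-finetuning-data',             # placeholder
        'split': 'train',
        'max_seq_len': 2048,
        'decoder_only_format': True,
        'shuffle': True,
        'allow_unsafe_types': False,  # new: error out on Pickle-encoded shard columns
        'replication': 2,             # new: every 2 consecutive ranks see identical samples
    },
    'drop_last': True,
    'num_workers': 8,
})
```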
10 changes: 10 additions & 0 deletions llmfoundry/data/finetuning/tasks.py
@@ -490,6 +490,12 @@ class StreamingFinetuningDataset(StreamingDataset):
Defaults to ``1``.
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
``per_stream``. Defaults to ``random``.
allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
execution during deserialization, whether to keep going (``True``) or raise an error
(``False``). Defaults to ``False``.
replication (int, optional): Determines how many consecutive devices will receive the same
samples. Useful for training with tensor or sequence parallelism, where multiple
devices need to see the same partition of the dataset. Defaults to ``None``.
"""

def __init__(self,
@@ -516,6 +522,8 @@ def __init__(self,
sampling_granularity: int = 1,
batching_method: str = 'random',
max_seq_len: int = 2048,
allow_unsafe_types: bool = False,
replication: Optional[int] = None,
**kwargs: Any):

if len(kwargs) > 0:
@@ -552,6 +560,8 @@ def __init__(self,
sampling_method=sampling_method,
sampling_granularity=sampling_granularity,
batching_method=batching_method,
allow_unsafe_types=allow_unsafe_types,
replication=replication,
)

self.tokenizer = tokenizer
24 changes: 22 additions & 2 deletions llmfoundry/data/packing.py
@@ -7,6 +7,7 @@

import numpy as np
import torch
from composer.utils import dist
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

@@ -315,6 +316,8 @@ def auto_packing_ratio(dataloader_cfg: DictConfig,
"""
from composer.utils import dist, get_device, reproducibility

log.debug('Searching for optimal packing ratio.')

# Stash the rng state to restore later.
rng_state = reproducibility.get_rng_state()
# Set the seed so that auto packing is deterministic.
@@ -388,8 +391,19 @@ def profile_packing(
dataloader_cfg.persistent_workers = False

# If streaming dataset, use a temporary local folder for profiling
local_rank_zero = dist.get_global_rank() - dist.get_local_rank()
if dataloader_cfg.dataset.get('remote') is not None:
dataloader_cfg.dataset.local = tempfile.TemporaryDirectory().name
tmp_path_to_broadcast = tempfile.TemporaryDirectory().name
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
tmp_path = gathered_paths[local_rank_zero]
dataloader_cfg.dataset.local = tmp_path

if dataloader_cfg.dataset.get('streams') is not None:
for stream_config in dataloader_cfg.dataset.streams.values():
tmp_path_to_broadcast = tempfile.TemporaryDirectory().name
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
tmp_path = gathered_paths[local_rank_zero]
stream_config.local = tmp_path

# Determine the packing_ratio values we'll try
packing_ratios, raw_batch_sizes = [], []
@@ -447,6 +461,12 @@ def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
waste_percent = 100 * packer.waste
return padding_percent, waste_percent

for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes):
log.debug('Profiling packing ratios')
total_packing_ratios = min(len(packing_ratios), len(raw_batch_sizes))
for i, (packing_ratio,
raw_batch_size) in enumerate(zip(packing_ratios, raw_batch_sizes)):
log.debug(
f'Progress [{i}/{total_packing_ratios}]: Profiling packing ratio {packing_ratio}'
)
padding, waste = profile(raw_batch_size)
yield (packing_ratio, padding, waste)
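The temporary-directory handling added above follows a share-from-local-rank-zero pattern: every rank proposes a temp path, the proposals are all-gathered, and each rank adopts the path proposed by rank zero of its node so that co-located workers profile against one local cache. Below is a standalone sketch of that pattern, assuming `composer.utils.dist` falls back to single-process behaviour when torch.distributed is not initialized.

```python
# Minimal sketch of the "adopt local-rank-zero's temp path" pattern used in
# profile_packing. Runs as-is in a single process; under a multi-rank launch,
# all ranks on a node end up with the same path.
import tempfile

from composer.utils import dist


def shared_local_tmp_path() -> str:
    # Global rank of the local rank-zero process on this node.
    local_rank_zero = dist.get_global_rank() - dist.get_local_rank()
    # Only the generated name is used; the directory itself is created later by
    # the streaming dataset, mirroring the original code.
    tmp_path_to_broadcast = tempfile.TemporaryDirectory().name
    gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
    return gathered_paths[local_rank_zero]


if __name__ == '__main__':
    print(shared_local_tmp_path())
```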
10 changes: 10 additions & 0 deletions llmfoundry/data/text_data.py
@@ -83,6 +83,12 @@ class StreamingTextDataset(StreamingDataset):
Defaults to ``1``.
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
``per_stream``. Defaults to ``random``.
allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
execution during deserialization, whether to keep going (``True``) or raise an error
(``False``). Defaults to ``False``.
replication (int, optional): Determines how many consecutive devices will receive the same
samples. Useful for training with tensor or sequence parallelism, where multiple
devices need to see the same partition of the dataset. Defaults to ``None``.
"""

def __init__(self,
@@ -109,6 +115,8 @@ def __init__(self,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
batching_method: str = 'random',
allow_unsafe_types: bool = False,
replication: Optional[int] = None,
**kwargs: Any):

if len(kwargs) > 0:
@@ -151,6 +159,8 @@ def __init__(self,
sampling_method=sampling_method,
sampling_granularity=sampling_granularity,
batching_method=batching_method,
allow_unsafe_types=allow_unsafe_types,
replication=replication,
)
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
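To make the `replication` semantics documented above concrete: assuming, as the docstring says, that each group of `replication` consecutive ranks receives the same samples, the rank-to-partition mapping looks like this (made-up world size).

```python
# Illustration only: grouping of ranks implied by the replication docstring.
def data_partition(rank: int, replication: int) -> int:
    return rank // replication


world_size, replication = 8, 2
for rank in range(world_size):
    print(f'rank {rank} -> dataset partition {data_partition(rank, replication)}')
# With replication=2, rank pairs (0,1), (2,3), (4,5), (6,7) each share a
# partition, matching tensor/sequence-parallel groups of size 2.
```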
20 changes: 20 additions & 0 deletions llmfoundry/layers_registry.py
@@ -0,0 +1,20 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Type

import torch

from llmfoundry.utils.registry_utils import create_registry

# Layers
_norm_description = """The norms registry is used to register classes that implement normalization layers."""
norms = create_registry('llmfoundry',
'norms',
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_norm_description)

__all__ = [
'norms',
]
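A hedged sketch of how downstream code might use this new registry: the class name and implementation below are made up, and the `register` call simply mirrors the callback registrations earlier in this diff.

```python
# Hypothetical registration of a custom normalization layer with the new
# `norms` registry; only the registry itself comes from this diff.
import torch

from llmfoundry.layers_registry import norms


class UnitRMSNorm(torch.nn.Module):
    """Toy norm that rescales activations to unit RMS (no learned weight)."""

    def __init__(self, normalized_shape: int, eps: float = 1e-5):
        super().__init__()
        # normalized_shape is accepted for interface parity but unused here.
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)


norms.register('unit_rmsnorm', func=UnitRMSNorm)
```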
16 changes: 14 additions & 2 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -12,8 +12,8 @@
from composer.models.huggingface import peft_installed
from composer.utils import dist
from omegaconf import DictConfig
from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedModel,
PreTrainedTokenizerBase)
from transformers import (AutoConfig, AutoModelForCausalLM, PretrainedConfig,
PreTrainedModel, PreTrainedTokenizerBase)

from llmfoundry.metrics import (DEFAULT_CAUSAL_LM_EVAL_METRICS,
DEFAULT_CAUSAL_LM_TRAIN_METRICS)
@@ -162,6 +162,18 @@ def _autoset_attn_implementation_monkeypatch(
elif attr is None and isinstance(v, Mapping):
setattr(config, k, {})
getattr(config, k).update(v)
elif isinstance(attr, PretrainedConfig):
if not isinstance(v, Mapping):
raise ValueError(
f'Expected a dictionary for config override {k}, but got {v}.'
)

for _k, _v in v.items():
if not hasattr(attr, _k):
raise ValueError(
f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).'
)
setattr(attr, _k, _v)
else:
setattr(config, k, v)

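The new `elif isinstance(attr, PretrainedConfig)` branch lets `config_overrides` update individual fields of a nested sub-config instead of replacing the whole object, and rejects overrides that name attributes the sub-config does not have. Here is a standalone illustration of the same logic on a toy nested config (no Hub checkpoint involved; `ffn_config` and `moe_top_k` are made-up attribute names).

```python
# Toy demonstration of the nested config-override behaviour added above.
from collections.abc import Mapping

from transformers import PretrainedConfig

inner = PretrainedConfig()      # stands in for a nested sub-config
inner.moe_top_k = 4
config = PretrainedConfig()
config.ffn_config = inner       # hypothetical nested attribute

config_overrides = {'ffn_config': {'moe_top_k': 2}}

for k, v in config_overrides.items():
    attr = getattr(config, k, None)
    if isinstance(attr, PretrainedConfig):
        if not isinstance(v, Mapping):
            raise ValueError(
                f'Expected a dictionary for config override {k}, but got {v}.')
        for _k, _v in v.items():
            if not hasattr(attr, _k):
                raise ValueError(
                    f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).')
            setattr(attr, _k, _v)
    else:
        setattr(config, k, v)

print(config.ffn_config.moe_top_k)  # -> 2
```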
3 changes: 1 addition & 2 deletions llmfoundry/models/layers/__init__.py
@@ -9,7 +9,7 @@
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm
from llmfoundry.models.layers.norm import LPLayerNorm

__all__ = [
'scaled_multihead_dot_product_attention',
@@ -23,7 +23,6 @@
'ATTN_CLASS_REGISTRY',
'MPTMLP',
'MPTBlock',
'NORM_CLASS_REGISTRY',
'LPLayerNorm',
'FC_CLASS_REGISTRY',
'SharedEmbedding',