
Commit

merged
milocress committed Apr 9, 2024
2 parents ea77a27 + e17d04f commit 5584de5
Showing 36 changed files with 733 additions and 444 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -91,6 +91,7 @@ Tutorial videos from the community:
Something missing? Contribute with a PR!

# Latest News
* [Blog: Introducing DBRX: A New State-of-the-Art Open LLM](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm)
* [Blog: LLM Training and Inference with Intel Gaudi2 AI Accelerators](https://www.databricks.com/blog/llm-training-and-inference-intel-gaudi2-ai-accelerators)
* [Blog: Training LLMs at Scale with AMD MI250 GPUs](https://www.databricks.com/blog/training-llms-scale-amd-mi250-gpus)
* [Blog: Training LLMs with AMD MI250 GPUs and MosaicML](https://www.mosaicml.com/blog/amd-mi250)
@@ -305,7 +306,7 @@ dependencies = [
"llm-foundry",
]

[project.entry-points."llm_foundry.loggers"]
[project.entry-points."llmfoundry_loggers"]
my_logger = "foundry_registry.loggers:MyLogger"
```
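For context, the module that entry point references might look roughly like the sketch below: a minimal, hypothetical `foundry_registry/loggers.py`, assuming Composer's `LoggerDestination` base class for logger plugins. The print-based body is illustrative only.

```python
# foundry_registry/loggers.py (hypothetical plugin module referenced by the
# entry point above; the print-based logger body is illustrative only).
from typing import Any, Dict, Optional

from composer.loggers import LoggerDestination


class MyLogger(LoggerDestination):
    """Toy logger destination that prints metrics to stdout."""

    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
        for name, value in metrics.items():
            print(f'step={step} {name}={value}')
```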

372 changes: 221 additions & 151 deletions llmfoundry/callbacks/hf_checkpointer.py

Large diffs are not rendered by default.

24 changes: 22 additions & 2 deletions llmfoundry/data/packing.py
@@ -7,6 +7,7 @@

import numpy as np
import torch
from composer.utils import dist
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

@@ -315,6 +316,8 @@ def auto_packing_ratio(dataloader_cfg: DictConfig,
"""
from composer.utils import dist, get_device, reproducibility

log.debug('Searching for optimal packing ratio.')

# Stash the rng state to restore later.
rng_state = reproducibility.get_rng_state()
# Set the seed so that auto packing is deterministic.
@@ -388,8 +391,19 @@ def profile_packing(
dataloader_cfg.persistent_workers = False

# If streaming dataset, use a temporary local folder for profiling
local_rank_zero = dist.get_global_rank() - dist.get_local_rank()
if dataloader_cfg.dataset.get('remote') is not None:
dataloader_cfg.dataset.local = tempfile.TemporaryDirectory().name
tmp_path_to_broadcast = tempfile.TemporaryDirectory().name
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
tmp_path = gathered_paths[local_rank_zero]
dataloader_cfg.dataset.local = tmp_path

if dataloader_cfg.dataset.get('streams') is not None:
for stream_config in dataloader_cfg.dataset.streams.values():
tmp_path_to_broadcast = tempfile.TemporaryDirectory().name
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
tmp_path = gathered_paths[local_rank_zero]
stream_config.local = tmp_path

# Determine the packing_ratio values we'll try
packing_ratios, raw_batch_sizes = [], []
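The temporary-directory logic in the hunk above follows a rank-consistency pattern: every rank proposes a path, and all ranks on a node then adopt the one proposed by that node's first rank, so they share a single local cache. A minimal standalone sketch of the pattern, assuming Composer's distributed helpers are initialized:

```python
import tempfile

from composer.utils import dist

# Every rank proposes a candidate temp path; gathering the proposals and
# indexing by (global_rank - local_rank) makes each rank adopt the path
# proposed by the first rank on its node, so the node shares one local dir.
local_rank_zero = dist.get_global_rank() - dist.get_local_rank()
tmp_path_to_broadcast = tempfile.TemporaryDirectory().name
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
shared_tmp_path = gathered_paths[local_rank_zero]
print(f'rank {dist.get_global_rank()} uses {shared_tmp_path}')
```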
@@ -447,6 +461,12 @@ def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
waste_percent = 100 * packer.waste
return padding_percent, waste_percent

for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes):
log.debug('Profiling packing ratios')
total_packing_ratios = min(len(packing_ratios), len(raw_batch_sizes))
for i, (packing_ratio,
raw_batch_size) in enumerate(zip(packing_ratios, raw_batch_sizes)):
log.debug(
f'Progress [{i}/{total_packing_ratios}]: Profiling packing ratio {packing_ratio}'
)
padding, waste = profile(raw_batch_size)
yield (packing_ratio, padding, waste)
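For a sense of how the yielded tuples get used, here is a hedged consumer sketch (the selection rule roughly mirrors auto packing: keep the largest ratio that wastes nothing; the function name is hypothetical):

```python
from typing import Iterable, Optional, Tuple


def pick_packing_ratio(
    results: Iterable[Tuple[float, Optional[float], Optional[float]]],
) -> float:
    """Report each measurement and keep the largest ratio with zero waste."""
    best = 1.0
    for packing_ratio, padding, waste in results:
        print(f'ratio={packing_ratio:.2f} padding={padding}% waste={waste}%')
        if waste == 0 and packing_ratio > best:
            best = packing_ratio
    return best
```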
20 changes: 20 additions & 0 deletions llmfoundry/layers_registry.py
@@ -0,0 +1,20 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Type

import torch

from llmfoundry.utils.registry_utils import create_registry

# Layers
_norm_description = """The norms registry is used to register classes that implement normalization layers."""
norms = create_registry('llmfoundry',
                        'norms',
                        generic_type=Type[torch.nn.Module],
                        entry_points=True,
                        description=_norm_description)

__all__ = [
    'norms',
]
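A quick usage sketch of the new registry (it mirrors how the `build_norm` helper added elsewhere in this commit calls `construct_from_registry`; treat the exact call as an assumption about that helper's API):

```python
import torch

import llmfoundry.models.layers.norm  # noqa: F401  (importing this module registers the built-in norms)
from llmfoundry.layers_registry import norms
from llmfoundry.utils.registry_utils import construct_from_registry

# List the registered norm names, then construct one by string name,
# the same way build_norm does in llmfoundry/models/layers/layer_builders.py.
print(norms.get_all())
layer = construct_from_registry(name='layernorm',
                                registry=norms,
                                pre_validation_function=torch.nn.Module,
                                kwargs={'normalized_shape': 512})
assert isinstance(layer, torch.nn.Module)
```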
16 changes: 14 additions & 2 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -11,8 +11,8 @@
from composer.models.huggingface import peft_installed
from composer.utils import dist
from omegaconf import DictConfig
from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedModel,
PreTrainedTokenizerBase)
from transformers import (AutoConfig, AutoModelForCausalLM, PretrainedConfig,
PreTrainedModel, PreTrainedTokenizerBase)

from llmfoundry.metrics import (DEFAULT_CAUSAL_LM_EVAL_METRICS,
DEFAULT_CAUSAL_LM_TRAIN_METRICS)
@@ -161,6 +161,18 @@ def _autoset_attn_implementation_monkeypatch(
elif attr is None and isinstance(v, Mapping):
setattr(config, k, {})
getattr(config, k).update(v)
elif isinstance(attr, PretrainedConfig):
if not isinstance(v, Mapping):
raise ValueError(
f'Expected a dictionary for config override {k}, but got {v}.'
)

for _k, _v in v.items():
if not hasattr(attr, _k):
raise ValueError(
f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).'
)
setattr(attr, _k, _v)
else:
setattr(config, k, v)

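A minimal, self-contained sketch of the new override semantics (the config classes below are hypothetical): a mapping aimed at a nested `PretrainedConfig` is applied key by key, and unknown keys raise rather than being silently added.

```python
from transformers import PretrainedConfig


class InnerConfig(PretrainedConfig):
    """Hypothetical nested sub-config with one known field."""

    def __init__(self, hidden_size: int = 64, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size


class OuterConfig(PretrainedConfig):
    """Hypothetical top-level config holding a nested PretrainedConfig."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.inner = InnerConfig()


config = OuterConfig()
overrides = {'inner': {'hidden_size': 128}}  # {'inner': {'typo_key': 1}} would raise

for k, v in overrides.items():
    attr = getattr(config, k)
    if isinstance(attr, PretrainedConfig):
        for _k, _v in v.items():
            if not hasattr(attr, _k):
                raise ValueError(f'config does not have attribute "{_k}" to override.')
            setattr(attr, _k, _v)

assert config.inner.hidden_size == 128
```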
3 changes: 1 addition & 2 deletions llmfoundry/models/layers/__init__.py
@@ -9,7 +9,7 @@
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm
from llmfoundry.models.layers.norm import LPLayerNorm

__all__ = [
'scaled_multihead_dot_product_attention',
@@ -23,7 +23,6 @@
'ATTN_CLASS_REGISTRY',
'MPTMLP',
'MPTBlock',
'NORM_CLASS_REGISTRY',
'LPLayerNorm',
'FC_CLASS_REGISTRY',
'SharedEmbedding',
22 changes: 15 additions & 7 deletions llmfoundry/models/layers/attention.py
@@ -15,7 +15,7 @@
from torch import nn

from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
from llmfoundry.models.layers.layer_builders import build_norm


def is_flash_v2_installed(v2_version: str = '2.0.0'):
@@ -419,12 +419,19 @@ def __init__(
self.Wqkv._fused = (0, fuse_splits)

if self.qk_ln or self.qk_gn:
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
norm_size = self.head_dim if qk_gn else d_model
self.q_ln = norm_class(norm_size, device=device)
self.q_ln = build_norm(
name=norm_type.lower(),
normalized_shape=norm_size,
device=device,
)
if qk_ln:
norm_size = self.head_dim * kv_n_heads
self.k_ln = norm_class(norm_size, device=device)
self.k_ln = build_norm(
name=norm_type.lower(),
normalized_shape=norm_size,
device=device,
)

if self.attn_impl == 'flash':
self.attn_fn = flash_attn_fn
@@ -501,9 +508,10 @@ def forward(
value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
elif rotary_emb_w_meta_info['impl'] == 'hf':
if is_transformers_version_gte('4.38'):
(cos, sin) = rotary_emb(x=value,
position_ids=offset_info,
seq_len=None)
(cos, sin) = rotary_emb(
x=value,
position_ids=offset_info,
)
else:
(cos, sin) = rotary_emb(x=value, seq_len=seq_len)
if is_transformers_version_gte('4.38'):
15 changes: 11 additions & 4 deletions llmfoundry/models/layers/blocks.py
@@ -10,7 +10,7 @@

from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
from llmfoundry.models.layers.layer_builders import build_norm

try:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
@@ -72,7 +72,6 @@ def __init__(
del kwargs # unused, just to capture any extra args from the config
super().__init__()

norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
assert isinstance(attn_config['attn_type'], str)
attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]

@@ -88,7 +87,11 @@ def __init__(
if k not in args_to_exclude_in_attn_class
}

self.norm_1 = norm_class(d_model, device=device)
self.norm_1 = build_norm(
name=norm_type.lower(),
normalized_shape=d_model,
device=device,
)
self.attn = attn_class(
d_model=d_model,
n_heads=n_heads,
@@ -100,7 +103,11 @@ def __init__(
self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
False):
self.norm_2 = norm_class(d_model, device=device)
self.norm_2 = build_norm(
name=norm_type.lower(),
normalized_shape=d_model,
device=device,
)
self.ffn = build_ffn(
d_model=d_model,
expansion_ratio=expansion_ratio,
25 changes: 25 additions & 0 deletions llmfoundry/models/layers/layer_builders.py
@@ -0,0 +1,25 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional, Union

import torch

from llmfoundry.layers_registry import norms
from llmfoundry.utils.registry_utils import construct_from_registry


def build_norm(
    name: str,
    normalized_shape: Union[int, List[int], torch.Size],
    device: Optional[str] = None,
):
    kwargs = {
        'normalized_shape': normalized_shape,
        'device': device,
    }

    return construct_from_registry(name=name,
                                   registry=norms,
                                   pre_validation_function=torch.nn.Module,
                                   kwargs=kwargs)
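A usage sketch for the helper (it assumes `llmfoundry.models.layers.norm` has been imported so the built-in names, including `rmsnorm`, are registered):

```python
import torch

import llmfoundry.models.layers.norm  # noqa: F401  (importing this module registers the built-in norms)
from llmfoundry.models.layers.layer_builders import build_norm

# Build a registered norm by name and run a tensor through it.
norm = build_norm(name='rmsnorm', normalized_shape=256)
x = torch.randn(2, 16, 256)
assert norm(x).shape == x.shape
```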
19 changes: 9 additions & 10 deletions llmfoundry/models/layers/norm.py
@@ -1,10 +1,14 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Optional, Type, Union
from typing import List, Optional, Union

import torch

from llmfoundry.layers_registry import norms

norms.register(name='layernorm', func=torch.nn.LayerNorm)


def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor:
if torch.is_autocast_enabled():
@@ -18,6 +22,7 @@ def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor:
return tensor


@norms.register_class('low_precision_layernorm')
class LPLayerNorm(torch.nn.LayerNorm):

def __init__(
@@ -62,6 +67,7 @@ def rms_norm(x: torch.Tensor,
return output


@norms.register_class('rmsnorm')
class RMSNorm(torch.nn.Module):

def __init__(
@@ -84,6 +90,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)


@norms.register_class('low_precision_rmsnorm')
class LPRMSNorm(RMSNorm):

def __init__(
@@ -111,6 +118,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
self.eps).to(dtype=x.dtype)


@norms.register_class('triton_rmsnorm')
class TritonRMSNorm(torch.nn.Module):

def __init__(
@@ -150,12 +158,3 @@ def forward(self, x: torch.Tensor):
prenorm=False,
residual_in_fp32=False,
)


NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = {
'layernorm': torch.nn.LayerNorm,
'low_precision_layernorm': LPLayerNorm,
'rmsnorm': RMSNorm,
'low_precision_rmsnorm': LPRMSNorm,
'triton_rmsnorm': TritonRMSNorm,
}
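With the module-level dictionary gone, adding a norm no longer means editing this file; a class can register itself with the same decorator used above. A hypothetical sketch (`unit_norm` and the class body are illustrative only):

```python
import torch

from llmfoundry.layers_registry import norms


@norms.register_class('unit_norm')  # hypothetical name, following the registrations above
class UnitNorm(torch.nn.Module):
    """Toy weight-free norm that rescales activations to unit RMS."""

    def __init__(self, normalized_shape, eps: float = 1e-6, device=None):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
```

Once the defining module is imported, the name resolves anywhere norms are looked up through the registry, for example `build_norm(name='unit_norm', normalized_shape=d_model)`.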
3 changes: 3 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -19,6 +19,9 @@
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note)
from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note)
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note)
from llmfoundry.models.layers.layer_builders import build_norm # type: ignore (see note)
from llmfoundry.layers_registry import norms # type: ignore (see note)
from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note)

ffn_config_defaults: Dict = {
'ffn_type': 'mptmlp',
14 changes: 9 additions & 5 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -20,7 +20,6 @@
from composer.utils import dist

from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

if is_flash_v2_installed():
try: # This try...except is needed because transformers requires it despite the 'if' statement above
@@ -42,11 +41,13 @@
from transformers.models.llama.modeling_llama import \
LlamaRotaryEmbedding as HFRotaryEmbedding

from llmfoundry.layers_registry import norms
from llmfoundry.models.layers.attention import (attn_bias_shape,
build_attn_bias, gen_slopes)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.ffn import build_ffn as build_ffn
from llmfoundry.models.layers.layer_builders import build_norm
from llmfoundry.models.mpt.configuration_mpt import MPTConfig

# NOTE: All utils are imported directly even if unused so that
@@ -297,12 +298,11 @@ def __init__(self, config: MPTConfig):
else:
config.init_device = 'meta'

if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
if config.norm_type.lower() not in norms.get_all():
norm_options = ' | '.join(norms.get_all())
raise NotImplementedError(
f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).'
)
norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]

# CogView (https://arxiv.org/abs/2105.13290) and GLM-130B (https://arxiv.org/abs/2210.02414)
# both report this helping with stabilizing training
@@ -329,7 +329,11 @@ def __init__(self, config: MPTConfig):
block.max_block_idx = config.n_layers - 1
pass_on_block_idx(block)

self.norm_f = norm_class(config.d_model, device=config.init_device)
self.norm_f = build_norm(
name=config.norm_type.lower(),
normalized_shape=config.d_model,
device=config.init_device,
)

self.rope = config.attn_config['rope']
self.rope_impl = None
Diffs for the remaining changed files are not shown.
